Format all files

Owen
2025-12-09 10:56:14 -05:00
parent fa839a811f
commit f9b03943c3
535 changed files with 7670 additions and 5626 deletions

View File

@@ -1,6 +1,3 @@
{
"extends": [
"next/core-web-vitals",
"next/typescript"
]
"extends": ["next/core-web-vitals", "next/typescript"]
}

View File

@@ -1,9 +1,7 @@
import { defineConfig } from "drizzle-kit";
import path from "path";
const schema = [
path.join("server", "db", "pg", "schema"),
];
const schema = [path.join("server", "db", "pg", "schema")];
export default defineConfig({
dialect: "postgresql",

View File

@@ -2,9 +2,7 @@ import { APP_PATH } from "@server/lib/consts";
import { defineConfig } from "drizzle-kit";
import path from "path";
const schema = [
path.join("server", "db", "sqlite", "schema"),
];
const schema = [path.join("server", "db", "sqlite", "schema")];
export default defineConfig({
dialect: "sqlite",

View File

@@ -24,20 +24,20 @@ const argv = yargs(hideBin(process.argv))
alias: "e",
describe: "Entry point file",
type: "string",
demandOption: true,
demandOption: true
})
.option("out", {
alias: "o",
describe: "Output file path",
type: "string",
demandOption: true,
demandOption: true
})
.option("build", {
alias: "b",
describe: "Build type (oss, saas, enterprise)",
type: "string",
choices: ["oss", "saas", "enterprise"],
default: "oss",
default: "oss"
})
.help()
.alias("help", "h").argv;
@@ -66,7 +66,9 @@ function privateImportGuardPlugin() {
// Check if the importing file is NOT in server/private
const normalizedImporter = path.normalize(importingFile);
const isInServerPrivate = normalizedImporter.includes(path.normalize("server/private"));
const isInServerPrivate = normalizedImporter.includes(
path.normalize("server/private")
);
if (!isInServerPrivate) {
const violation = {
@@ -79,8 +81,8 @@ function privateImportGuardPlugin() {
console.log(`PRIVATE IMPORT VIOLATION:`);
console.log(` File: ${importingFile}`);
console.log(` Import: ${args.path}`);
console.log(` Resolve dir: ${args.resolveDir || 'N/A'}`);
console.log('');
console.log(` Resolve dir: ${args.resolveDir || "N/A"}`);
console.log("");
}
// Return null to let the default resolver handle it
@@ -89,16 +91,20 @@ function privateImportGuardPlugin() {
build.onEnd((result) => {
if (violations.length > 0) {
console.log(`\nSUMMARY: Found ${violations.length} private import violation(s):`);
console.log(
`\nSUMMARY: Found ${violations.length} private import violation(s):`
);
violations.forEach((v, i) => {
console.log(` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`);
console.log(
` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`
);
});
console.log('');
console.log("");
result.errors.push({
text: `Private import violations detected: ${violations.length} violation(s) found`,
location: null,
notes: violations.map(v => ({
notes: violations.map((v) => ({
text: `${path.relative(process.cwd(), v.file)} imports ${v.importPath}`,
location: null
}))
@@ -121,7 +127,9 @@ function dynamicImportGuardPlugin() {
// Check if the importing file is NOT in server/private
const normalizedImporter = path.normalize(importingFile);
const isInServerPrivate = normalizedImporter.includes(path.normalize("server/private"));
const isInServerPrivate = normalizedImporter.includes(
path.normalize("server/private")
);
if (isInServerPrivate) {
const violation = {
@@ -134,8 +142,8 @@ function dynamicImportGuardPlugin() {
console.log(`DYNAMIC IMPORT VIOLATION:`);
console.log(` File: ${importingFile}`);
console.log(` Import: ${args.path}`);
console.log(` Resolve dir: ${args.resolveDir || 'N/A'}`);
console.log('');
console.log(` Resolve dir: ${args.resolveDir || "N/A"}`);
console.log("");
}
// Return null to let the default resolver handle it
@@ -144,16 +152,20 @@ function dynamicImportGuardPlugin() {
build.onEnd((result) => {
if (violations.length > 0) {
console.log(`\nSUMMARY: Found ${violations.length} dynamic import violation(s):`);
console.log(
`\nSUMMARY: Found ${violations.length} dynamic import violation(s):`
);
violations.forEach((v, i) => {
console.log(` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`);
console.log(
` ${i + 1}. ${path.relative(process.cwd(), v.file)} imports ${v.importPath}`
);
});
console.log('');
console.log("");
result.errors.push({
text: `Dynamic import violations detected: ${violations.length} violation(s) found`,
location: null,
notes: violations.map(v => ({
notes: violations.map((v) => ({
text: `${path.relative(process.cwd(), v.file)} imports ${v.importPath}`,
location: null
}))
@@ -172,21 +184,28 @@ function dynamicImportSwitcherPlugin(buildValue) {
const switches = [];
build.onStart(() => {
console.log(`Dynamic import switcher using build type: ${buildValue}`);
console.log(
`Dynamic import switcher using build type: ${buildValue}`
);
});
build.onResolve({ filter: /^#dynamic\// }, (args) => {
// Extract the path after #dynamic/
const dynamicPath = args.path.replace(/^#dynamic\//, '');
const dynamicPath = args.path.replace(/^#dynamic\//, "");
// Determine the replacement based on build type
let replacement;
if (buildValue === "oss") {
replacement = `#open/${dynamicPath}`;
} else if (buildValue === "saas" || buildValue === "enterprise") {
} else if (
buildValue === "saas" ||
buildValue === "enterprise"
) {
replacement = `#closed/${dynamicPath}`; // We use #closed here so that the route guards don't complain after it's been changed, but this is the same as #private
} else {
console.warn(`Unknown build type '${buildValue}', defaulting to #open/`);
console.warn(
`Unknown build type '${buildValue}', defaulting to #open/`
);
replacement = `#open/${dynamicPath}`;
}
@@ -201,8 +220,10 @@ function dynamicImportSwitcherPlugin(buildValue) {
console.log(`DYNAMIC IMPORT SWITCH:`);
console.log(` File: ${args.importer}`);
console.log(` Original: ${args.path}`);
console.log(` Switched to: ${replacement} (build: ${buildValue})`);
console.log('');
console.log(
` Switched to: ${replacement} (build: ${buildValue})`
);
console.log("");
// Rewrite the import path and let the normal resolution continue
return build.resolve(replacement, {
@@ -215,12 +236,18 @@ function dynamicImportSwitcherPlugin(buildValue) {
build.onEnd((result) => {
if (switches.length > 0) {
console.log(`\nDYNAMIC IMPORT SUMMARY: Switched ${switches.length} import(s) for build type '${buildValue}':`);
console.log(
`\nDYNAMIC IMPORT SUMMARY: Switched ${switches.length} import(s) for build type '${buildValue}':`
);
switches.forEach((s, i) => {
console.log(` ${i + 1}. ${path.relative(process.cwd(), s.file)}`);
console.log(` ${s.originalPath} ${s.replacementPath}`);
console.log(
` ${i + 1}. ${path.relative(process.cwd(), s.file)}`
);
console.log(
` ${s.originalPath} ${s.replacementPath}`
);
});
console.log('');
console.log("");
}
});
}
@@ -235,7 +262,7 @@ esbuild
format: "esm",
minify: false,
banner: {
js: banner,
js: banner
},
platform: "node",
external: ["body-parser"],
@@ -244,20 +271,22 @@ esbuild
dynamicImportGuardPlugin(),
dynamicImportSwitcherPlugin(argv.build),
nodeExternalsPlugin({
packagePath: getPackagePaths(),
}),
packagePath: getPackagePaths()
})
],
sourcemap: "inline",
target: "node22",
target: "node22"
})
.then((result) => {
// Check if there were any errors in the build result
if (result.errors && result.errors.length > 0) {
console.error(`Build failed with ${result.errors.length} error(s):`);
console.error(
`Build failed with ${result.errors.length} error(s):`
);
result.errors.forEach((error, i) => {
console.error(`${i + 1}. ${error.text}`);
if (error.notes) {
error.notes.forEach(note => {
error.notes.forEach((note) => {
console.error(` - ${note.text}`);
});
}
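For context, a minimal sketch of what the dynamic-import switcher above means for a consuming module; the `#dynamic/billing` specifier and `getInvoice` name are hypothetical, not taken from this repository:

```ts
// Hypothetical consumer module. With --build oss the switcher plugin
// rewrites the "#dynamic/" prefix to "#open/"; with --build saas or
// --build enterprise it rewrites to "#closed/". The source file itself
// never changes; only the resolved import does.
import { getInvoice } from "#dynamic/billing";

// After the onResolve hook runs, esbuild effectively bundles one of:
//   import { getInvoice } from "#open/billing";   // oss
//   import { getInvoice } from "#closed/billing"; // saas / enterprise
export async function printInvoice(id: string) {
    console.log(await getInvoice(id));
}
```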

View File

@@ -1,4 +1,4 @@
import tseslint from 'typescript-eslint';
import tseslint from "typescript-eslint";
export default tseslint.config({
files: ["**/*.{ts,tsx,js,jsx}"],
@@ -13,7 +13,7 @@ export default tseslint.config({
}
},
rules: {
"semi": "error",
semi: "error",
"prefer-const": "warn"
}
});

View File

@@ -1,8 +1,8 @@
/** @type {import('postcss-load-config').Config} */
const config = {
plugins: {
"@tailwindcss/postcss": {},
},
"@tailwindcss/postcss": {}
}
};
export default config;

View File

@@ -2,13 +2,13 @@ import { hash, verify } from "@node-rs/argon2";
export async function verifyPassword(
password: string,
hash: string,
hash: string
): Promise<boolean> {
const validPassword = await verify(hash, password, {
memoryCost: 19456,
timeCost: 2,
outputLen: 32,
parallelism: 1,
parallelism: 1
});
return validPassword;
}
@@ -18,7 +18,7 @@ export async function hashPassword(password: string): Promise<string> {
memoryCost: 19456,
timeCost: 2,
outputLen: 32,
parallelism: 1,
parallelism: 1
});
return passwordHash;

View File

@@ -4,10 +4,13 @@ export const passwordSchema = z
.string()
.min(8, { message: "Password must be at least 8 characters long" })
.max(128, { message: "Password must be at most 128 characters long" })
.regex(/^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9])(?=.*?[~!`@#$%^&*()_\-+={}[\]|\\:;"'<>,.\/?]).*$/, {
.regex(
/^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9])(?=.*?[~!`@#$%^&*()_\-+={}[\]|\\:;"'<>,.\/?]).*$/,
{
message: `Your password must meet the following conditions:
at least one uppercase English letter,
at least one lowercase English letter,
at least one digit,
at least one special character.`
});
}
);
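As a quick illustration of the reformatted validator's behavior, a sketch with made-up sample passwords (the pattern is copied from passwordSchema above):

```ts
// Four lookaheads require an uppercase letter, a lowercase letter,
// a digit, and a special character; length is enforced separately
// by .min(8)/.max(128) on the schema.
const pattern =
    /^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9])(?=.*?[~!`@#$%^&*()_\-+={}[\]|\\:;"'<>,.\/?]).*$/;

console.log(pattern.test("Str0ng!pass")); // true: all four classes present
console.log(pattern.test("weakpassword")); // false: no uppercase, digit, or special
console.log(pattern.test("Sh0rt!1")); // true here, but .min(8) still rejects it
```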

View File

@@ -1,6 +1,4 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Newt, newts, newtSessions, NewtSession } from "@server/db";
import { db } from "@server/db";
@@ -10,25 +8,25 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createNewtSession(
token: string,
newtId: string,
newtId: string
): Promise<NewtSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: NewtSession = {
sessionId: sessionId,
newtId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(newtSessions).values(session);
return session;
}
export async function validateNewtSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ newt: newts, session: newtSessions })
@@ -45,14 +43,12 @@ export async function validateNewtSessionToken(
.where(eq(newtSessions.sessionId, session.sessionId));
return { session: null, newt: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(newtSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(newtSessions.sessionId, session.sessionId));
}
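The renewal branch above gives sessions a sliding 30-day window: once less than half the lifetime remains, expiry is pushed out to a full EXPIRES from now. A standalone sketch of that check, using millisecond timestamps as the code above does:

```ts
const EXPIRES = 1000 * 60 * 60 * 24 * 30; // 30 days, matching the module above

// Returns a new expiresAt when the session should be renewed, or null
// when it should be left alone (the caller handles hard expiry separately).
function maybeRenew(expiresAt: number, now: number = Date.now()): number | null {
    if (now >= expiresAt) return null; // expired: the validate function deletes the row instead
    if (now >= expiresAt - EXPIRES / 2) {
        return now + EXPIRES; // fewer than 15 days left: extend the window
    }
    return null;
}
```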

View File

@@ -1,6 +1,4 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Olm, olms, olmSessions, OlmSession } from "@server/db";
import { db } from "@server/db";
@@ -10,25 +8,25 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createOlmSession(
token: string,
olmId: string,
olmId: string
): Promise<OlmSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: OlmSession = {
sessionId: sessionId,
olmId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(olmSessions).values(session);
return session;
}
export async function validateOlmSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ olm: olms, session: olmSessions })
@@ -45,14 +43,12 @@ export async function validateOlmSessionToken(
.where(eq(olmSessions.sessionId, session.sessionId));
return { session: null, olm: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(olmSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(olmSessions.sessionId, session.sessionId));
}

File diff suppressed because it is too large

View File

@@ -215,7 +215,9 @@ export const sessionTransferToken = pgTable("sessionTransferToken", {
expiresAt: bigint("expiresAt", { mode: "number" }).notNull()
});
export const actionAuditLog = pgTable("actionAuditLog", {
export const actionAuditLog = pgTable(
"actionAuditLog",
{
id: serial("id").primaryKey(),
timestamp: bigint("timestamp", { mode: "number" }).notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
@@ -226,12 +228,19 @@ export const actionAuditLog = pgTable("actionAuditLog", {
actorId: varchar("actorId", { length: 255 }).notNull(),
action: varchar("action", { length: 100 }).notNull(),
metadata: text("metadata")
}, (table) => ([
},
(table) => [
index("idx_actionAuditLog_timestamp").on(table.timestamp),
index("idx_actionAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
index("idx_actionAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export const accessAuditLog = pgTable("accessAuditLog", {
export const accessAuditLog = pgTable(
"accessAuditLog",
{
id: serial("id").primaryKey(),
timestamp: bigint("timestamp", { mode: "number" }).notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
@@ -247,10 +256,15 @@ export const accessAuditLog = pgTable("accessAuditLog", {
location: text("location"),
userAgent: text("userAgent"),
metadata: text("metadata")
}, (table) => ([
},
(table) => [
index("idx_identityAuditLog_timestamp").on(table.timestamp),
index("idx_identityAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
index("idx_identityAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
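The `(table) => [...]` callback is Drizzle's third argument for declaring indexes alongside the columns. A hedged sketch of the query shape those two indexes serve (assuming `db` and `actionAuditLog` are re-exported from `@server/db`, as other files in this commit suggest; the `since` cutoff is illustrative):

```ts
import { and, eq, gte } from "drizzle-orm";
import { db, actionAuditLog } from "@server/db";

// Recent audit entries for one org. The compound
// idx_actionAuditLog_org_timestamp index covers this (orgId, timestamp)
// filter; the timestamp-only index serves org-agnostic retention sweeps.
async function recentActions(orgId: string, since: number) {
    return db
        .select()
        .from(actionAuditLog)
        .where(
            and(
                eq(actionAuditLog.orgId, orgId),
                gte(actionAuditLog.timestamp, since)
            )
        );
}
```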

View File

@@ -177,7 +177,7 @@ export const targetHealthCheck = pgTable("targetHealthCheck", {
hcMethod: varchar("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName"),
hcTlsServerName: text("hcTlsServerName")
});
export const exitNodes = pgTable("exitNodes", {

View File

@@ -52,10 +52,7 @@ export async function getResourceByDomain(
resourceHeaderAuth,
eq(resourceHeaderAuth.resourceId, resources.resourceId)
)
.innerJoin(
orgs,
eq(orgs.orgId, resources.orgId)
)
.innerJoin(orgs, eq(orgs.orgId, resources.orgId))
.where(eq(resources.fullDomain, domain))
.limit(1);

View File

@@ -8,7 +8,7 @@ const runMigrations = async () => {
console.log("Running migrations...");
try {
migrate(db as any, {
migrationsFolder: migrationsFolder,
migrationsFolder: migrationsFolder
});
console.log("Migrations completed successfully.");
} catch (error) {

View File

@@ -29,7 +29,9 @@ export const certificates = sqliteTable("certificates", {
});
export const dnsChallenge = sqliteTable("dnsChallenges", {
dnsChallengeId: integer("dnsChallengeId").primaryKey({ autoIncrement: true }),
dnsChallengeId: integer("dnsChallengeId").primaryKey({
autoIncrement: true
}),
domain: text("domain").notNull(),
token: text("token").notNull(),
keyAuthorization: text("keyAuthorization").notNull(),
@@ -61,9 +63,7 @@ export const customers = sqliteTable("customers", {
});
export const subscriptions = sqliteTable("subscriptions", {
subscriptionId: text("subscriptionId")
.primaryKey()
.notNull(),
subscriptionId: text("subscriptionId").primaryKey().notNull(),
customerId: text("customerId")
.notNull()
.references(() => customers.customerId, { onDelete: "cascade" }),
@@ -75,7 +75,9 @@ export const subscriptions = sqliteTable("subscriptions", {
});
export const subscriptionItems = sqliteTable("subscriptionItems", {
subscriptionItemId: integer("subscriptionItemId").primaryKey({ autoIncrement: true }),
subscriptionItemId: integer("subscriptionItemId").primaryKey({
autoIncrement: true
}),
subscriptionId: text("subscriptionId")
.notNull()
.references(() => subscriptions.subscriptionId, {
@@ -129,7 +131,9 @@ export const limits = sqliteTable("limits", {
});
export const usageNotifications = sqliteTable("usageNotifications", {
notificationId: integer("notificationId").primaryKey({ autoIncrement: true }),
notificationId: integer("notificationId").primaryKey({
autoIncrement: true
}),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
@@ -210,7 +214,9 @@ export const sessionTransferToken = sqliteTable("sessionTransferToken", {
expiresAt: integer("expiresAt").notNull()
});
export const actionAuditLog = sqliteTable("actionAuditLog", {
export const actionAuditLog = sqliteTable(
"actionAuditLog",
{
id: integer("id").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
@@ -221,12 +227,19 @@ export const actionAuditLog = sqliteTable("actionAuditLog", {
actorId: text("actorId").notNull(),
action: text("action").notNull(),
metadata: text("metadata")
}, (table) => ([
},
(table) => [
index("idx_actionAuditLog_timestamp").on(table.timestamp),
index("idx_actionAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
index("idx_actionAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export const accessAuditLog = sqliteTable("accessAuditLog", {
export const accessAuditLog = sqliteTable(
"accessAuditLog",
{
id: integer("id").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
@@ -242,10 +255,15 @@ export const accessAuditLog = sqliteTable("accessAuditLog", {
action: integer("action", { mode: "boolean" }).notNull(),
userAgent: text("userAgent"),
metadata: text("metadata")
}, (table) => ([
},
(table) => [
index("idx_identityAuditLog_timestamp").on(table.timestamp),
index("idx_identityAuditLog_org_timestamp").on(table.orgId, table.timestamp)
]));
index("idx_identityAuditLog_org_timestamp").on(
table.orgId,
table.timestamp
)
]
);
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;

View File

@@ -18,10 +18,13 @@ function createEmailClient() {
host: emailConfig.smtp_host,
port: emailConfig.smtp_port,
secure: emailConfig.smtp_secure || false,
auth: (emailConfig.smtp_user && emailConfig.smtp_pass) ? {
auth:
emailConfig.smtp_user && emailConfig.smtp_pass
? {
user: emailConfig.smtp_user,
pass: emailConfig.smtp_pass
} : null
}
: null
} as SMTPTransport.Options;
if (emailConfig.smtp_tls_reject_unauthorized !== undefined) {
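The reformatted ternary above only changes shape, not behavior: auth is attached when both credentials exist and omitted otherwise. A minimal standalone sketch of the same pattern (field names follow the snippet; everything else is illustrative):

```ts
import type SMTPTransport from "nodemailer/lib/smtp-transport";

// Attach auth only when both credentials are present; an open relay
// gets no auth block at all rather than a half-filled one.
function transportOptions(cfg: {
    smtp_host: string;
    smtp_port: number;
    smtp_secure?: boolean;
    smtp_user?: string;
    smtp_pass?: string;
}): SMTPTransport.Options {
    return {
        host: cfg.smtp_host,
        port: cfg.smtp_port,
        secure: cfg.smtp_secure || false,
        auth:
            cfg.smtp_user && cfg.smtp_pass
                ? { user: cfg.smtp_user, pass: cfg.smtp_pass }
                : undefined // the original passes null and casts via `as SMTPTransport.Options`
    };
}
```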

View File

@@ -19,7 +19,13 @@ interface Props {
billingLink: string; // Link to billing page
}
export const NotifyUsageLimitApproaching = ({ email, limitName, currentUsage, usageLimit, billingLink }: Props) => {
export const NotifyUsageLimitApproaching = ({
email,
limitName,
currentUsage,
usageLimit,
billingLink
}: Props) => {
const previewText = `Your usage for ${limitName} is approaching the limit.`;
const usagePercentage = Math.round((currentUsage / usageLimit) * 100);
@@ -37,23 +43,32 @@ export const NotifyUsageLimitApproaching = ({ email, limitName, currentUsage, us
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
We wanted to let you know that your usage for <strong>{limitName}</strong> is approaching your plan limit.
We wanted to let you know that your usage for{" "}
<strong>{limitName}</strong> is approaching your
plan limit.
</EmailText>
<EmailText>
<strong>Current Usage:</strong> {currentUsage} of {usageLimit} ({usagePercentage}%)
<strong>Current Usage:</strong> {currentUsage} of{" "}
{usageLimit} ({usagePercentage}%)
</EmailText>
<EmailText>
Once you reach your limit, some functionality may be restricted or your sites may disconnect until you upgrade your plan or your usage resets.
Once you reach your limit, some functionality may be
restricted or your sites may disconnect until you
upgrade your plan or your usage resets.
</EmailText>
<EmailText>
To avoid any interruption to your service, we recommend upgrading your plan or monitoring your usage closely. You can <a href={billingLink}>upgrade your plan here</a>.
To avoid any interruption to your service, we
recommend upgrading your plan or monitoring your
usage closely. You can{" "}
<a href={billingLink}>upgrade your plan here</a>.
</EmailText>
<EmailText>
If you have any questions or need assistance, please don't hesitate to reach out to our support team.
If you have any questions or need assistance, please
don't hesitate to reach out to our support team.
</EmailText>
<EmailFooter>

View File

@@ -19,7 +19,13 @@ interface Props {
billingLink: string; // Link to billing page
}
export const NotifyUsageLimitReached = ({ email, limitName, currentUsage, usageLimit, billingLink }: Props) => {
export const NotifyUsageLimitReached = ({
email,
limitName,
currentUsage,
usageLimit,
billingLink
}: Props) => {
const previewText = `You've reached your ${limitName} usage limit - Action required`;
const usagePercentage = Math.round((currentUsage / usageLimit) * 100);
@@ -32,30 +38,48 @@ export const NotifyUsageLimitReached = ({ email, limitName, currentUsage, usageL
<EmailContainer>
<EmailLetterHead />
<EmailHeading>Usage Limit Reached - Action Required</EmailHeading>
<EmailHeading>
Usage Limit Reached - Action Required
</EmailHeading>
<EmailGreeting>Hi there,</EmailGreeting>
<EmailText>
You have reached your usage limit for <strong>{limitName}</strong>.
You have reached your usage limit for{" "}
<strong>{limitName}</strong>.
</EmailText>
<EmailText>
<strong>Current Usage:</strong> {currentUsage} of {usageLimit} ({usagePercentage}%)
<strong>Current Usage:</strong> {currentUsage} of{" "}
{usageLimit} ({usagePercentage}%)
</EmailText>
<EmailText>
<strong>Important:</strong> Your functionality may now be restricted and your sites may disconnect until you either upgrade your plan or your usage resets. To prevent any service interruption, immediate action is recommended.
<strong>Important:</strong> Your functionality may
now be restricted and your sites may disconnect
until you either upgrade your plan or your usage
resets. To prevent any service interruption,
immediate action is recommended.
</EmailText>
<EmailText>
<strong>What you can do:</strong>
<br /> <a href={billingLink} style={{ color: '#2563eb', fontWeight: 'bold' }}>Upgrade your plan immediately</a> to restore full functionality
<br /> Monitor your usage to stay within limits in the future
<br />{" "}
<a
href={billingLink}
style={{ color: "#2563eb", fontWeight: "bold" }}
>
Upgrade your plan immediately
</a>{" "}
to restore full functionality
<br /> Monitor your usage to stay within limits in
the future
</EmailText>
<EmailText>
If you have any questions or need immediate assistance, please contact our support team right away.
If you have any questions or need immediate
assistance, please contact our support team right
away.
</EmailText>
<EmailFooter>

View File

@@ -5,7 +5,7 @@ import config from "@server/lib/config";
import logger from "@server/logger";
import {
errorHandlerMiddleware,
notFoundMiddleware,
notFoundMiddleware
} from "@server/middlewares";
import { authenticated, unauthenticated } from "#dynamic/routers/integration";
import { logIncomingMiddleware } from "./middlewares/logIncoming";

View File

@@ -25,16 +25,22 @@ export const FeatureMeterIdsSandbox: Record<FeatureId, string> = {
};
export function getFeatureMeterId(featureId: FeatureId): string {
if (process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return FeatureMeterIds[featureId];
} else {
return FeatureMeterIdsSandbox[featureId];
}
}
export function getFeatureIdByMetricId(metricId: string): FeatureId | undefined {
return (Object.entries(FeatureMeterIds) as [FeatureId, string][])
.find(([_, v]) => v === metricId)?.[0];
export function getFeatureIdByMetricId(
metricId: string
): FeatureId | undefined {
return (Object.entries(FeatureMeterIds) as [FeatureId, string][]).find(
([_, v]) => v === metricId
)?.[0];
}
export type FeaturePriceSet = {
@@ -43,7 +49,8 @@ export type FeaturePriceSet = {
[FeatureId.DOMAINS]?: string; // Optional since domains are not billed
};
export const standardFeaturePriceSet: FeaturePriceSet = { // Free tier matches the freeLimitSet
export const standardFeaturePriceSet: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RrQc4D3Ee2Ir7WmaJGZ3MtF",
[FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea",
[FeatureId.EGRESS_DATA_MB]: "price_1RrQXFD3Ee2Ir7WmvGDlgxQk",
@@ -51,7 +58,8 @@ export const standardFeaturePriceSet: FeaturePriceSet = { // Free tier matches t
[FeatureId.REMOTE_EXIT_NODES]: "price_1S46weD3Ee2Ir7Wm94KEHI4h"
};
export const standardFeaturePriceSetSandbox: FeaturePriceSet = { // Free tier matches the freeLimitSet
export const standardFeaturePriceSetSandbox: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RefFBDCpkOb237BPrKZ8IEU",
[FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF",
[FeatureId.EGRESS_DATA_MB]: "price_1Rfp9LDCpkOb237BwuN5Oiu0",
@@ -60,15 +68,20 @@ export const standardFeaturePriceSetSandbox: FeaturePriceSet = { // Free tier ma
};
export function getStandardFeaturePriceSet(): FeaturePriceSet {
if (process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return standardFeaturePriceSet;
} else {
return standardFeaturePriceSetSandbox;
}
}
export function getLineItems(featurePriceSet: FeaturePriceSet): Stripe.Checkout.SessionCreateParams.LineItem[] {
export function getLineItems(
featurePriceSet: FeaturePriceSet
): Stripe.Checkout.SessionCreateParams.LineItem[] {
return Object.entries(featurePriceSet).map(([featureId, priceId]) => ({
price: priceId,
price: priceId
}));
}
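A short usage sketch for the reformatted reverse lookup; the meter ID string is hypothetical:

```ts
// Map an incoming Stripe meter event back to the feature it bills.
const featureId = getFeatureIdByMetricId("mtr_example_site_uptime"); // hypothetical ID
if (featureId) {
    console.log(`Meter maps to feature: ${featureId}`);
} else {
    console.log("Unknown meter; ignoring event");
}
```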

View File

@@ -12,7 +12,7 @@ export const sandboxLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.EGRESS_DATA_MB]: { value: 1000, description: "Sandbox limit" }, // 1 GB
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }
};
export const freeLimitSet: LimitSet = {
@@ -29,7 +29,7 @@ export const freeLimitSet: LimitSet = {
export const subscribedLimitSet: LimitSet = {
[FeatureId.SITE_UPTIME]: {
value: 2232000,
description: "Contact us to increase soft limit.",
description: "Contact us to increase soft limit."
}, // 50 sites up for 31 days
[FeatureId.USERS]: {
value: 150,

View File

@@ -1,22 +1,32 @@
export enum TierId {
STANDARD = "standard",
STANDARD = "standard"
}
export type TierPriceSet = {
[key in TierId]: string;
};
export const tierPriceSet: TierPriceSet = { // Free tier matches the freeLimitSet
[TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0",
export const tierPriceSet: TierPriceSet = {
// Free tier matches the freeLimitSet
[TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0"
};
export const tierPriceSetSandbox: TierPriceSet = { // Free tier matches the freeLimitSet
export const tierPriceSetSandbox: TierPriceSet = {
// Free tier matches the freeLimitSet
// when matching tiers, keys closer to index 0 are matched first, so list the tiers in descending order of value
[TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m",
[TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m"
};
export function getTierPriceSet(environment?: string, sandbox_mode?: boolean): TierPriceSet {
if ((process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") || (environment === "prod" && sandbox_mode !== true)) { // THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
export function getTierPriceSet(
environment?: string,
sandbox_mode?: boolean
): TierPriceSet {
if (
(process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true") ||
(environment === "prod" && sandbox_mode !== true)
) {
// THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
return tierPriceSet;
} else {
return tierPriceSetSandbox;

View File

@@ -34,7 +34,10 @@ export async function applyNewtDockerBlueprint(
return;
}
if (isEmptyObject(blueprint["proxy-resources"]) && isEmptyObject(blueprint["client-resources"])) {
if (
isEmptyObject(blueprint["proxy-resources"]) &&
isEmptyObject(blueprint["client-resources"])
) {
return;
}

View File

@@ -84,12 +84,20 @@ export function processContainerLabels(containers: Container[]): {
// Process proxy resources
if (Object.keys(proxyResourceLabels).length > 0) {
processResourceLabels(proxyResourceLabels, container, result["proxy-resources"]);
processResourceLabels(
proxyResourceLabels,
container,
result["proxy-resources"]
);
}
// Process client resources
if (Object.keys(clientResourceLabels).length > 0) {
processResourceLabels(clientResourceLabels, container, result["client-resources"]);
processResourceLabels(
clientResourceLabels,
container,
result["client-resources"]
);
}
});
@@ -161,8 +169,7 @@ function processResourceLabels(
const finalTarget = { ...target };
if (!finalTarget.hostname) {
finalTarget.hostname =
container.name ||
container.hostname;
container.name || container.hostname;
}
if (!finalTarget.port) {
const containerPort =

View File

@@ -324,7 +324,10 @@ export const ConfigSchema = z
return data as {
"proxy-resources": Record<string, z.infer<typeof ResourceSchema>>;
"client-resources": Record<string, z.infer<typeof ClientResourceSchema>>;
"client-resources": Record<
string,
z.infer<typeof ClientResourceSchema>
>;
sites: Record<string, z.infer<typeof SiteSchema>>;
};
})

View File

@@ -166,7 +166,10 @@ export async function calculateUserClientsForOrgs(
];
// Get next available subnet
const newSubnet = await getNextAvailableClientSubnet(orgId, transaction);
const newSubnet = await getNextAvailableClientSubnet(
orgId,
transaction
);
if (!newSubnet) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no available subnet found`

View File

@@ -1,4 +1,6 @@
export async function getValidCertificatesForDomains(domains: Set<string>): Promise<
export async function getValidCertificatesForDomains(
domains: Set<string>
): Promise<
Array<{
id: number;
domain: string;

View File

@@ -7,7 +7,10 @@ function dateToTimestamp(dateStr: string): number {
// Testable version of calculateCutoffTimestamp that accepts a "now" timestamp
// This matches the logic in cleanupLogs.ts but allows injecting the current time
function calculateCutoffTimestampWithNow(retentionDays: number, nowTimestamp: number): number {
function calculateCutoffTimestampWithNow(
retentionDays: number,
nowTimestamp: number
): number {
if (retentionDays === 9001) {
// Special case: data is erased at the end of the year following the year it was generated
// This means we delete logs from 2 years ago or older (logs from year Y are deleted after Dec 31 of year Y+1)
@@ -28,7 +31,7 @@ function testCalculateCutoffTimestamp() {
{
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(30, now);
const expected = now - (30 * 24 * 60 * 60);
const expected = now - 30 * 24 * 60 * 60;
assertEquals(result, expected, "30 days retention calculation failed");
}
@@ -36,7 +39,7 @@ function testCalculateCutoffTimestamp() {
{
const now = dateToTimestamp("2025-06-15T00:00:00Z");
const result = calculateCutoffTimestampWithNow(90, now);
const expected = now - (90 * 24 * 60 * 60);
const expected = now - 90 * 24 * 60 * 60;
assertEquals(result, expected, "90 days retention calculation failed");
}
@@ -48,7 +51,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Dec 2025) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (Dec 2025) - should cutoff at Jan 1, 2024"
);
}
// Test 4: Special case 9001 - January 2026
@@ -58,7 +65,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2026-01-15T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2025-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Jan 2026) - should cutoff at Jan 1, 2025");
assertEquals(
result,
expected,
"9001 retention (Jan 2026) - should cutoff at Jan 1, 2025"
);
}
// Test 5: Special case 9001 - December 31, 2025 at 23:59:59 UTC
@@ -68,7 +79,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-12-31T23:59:59Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Dec 31, 2025 23:59:59) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (Dec 31, 2025 23:59:59) - should cutoff at Jan 1, 2024"
);
}
// Test 6: Special case 9001 - January 1, 2026 at 00:00:01 UTC
@@ -78,7 +93,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2026-01-01T00:00:01Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2025-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Jan 1, 2026 00:00:01) - should cutoff at Jan 1, 2025");
assertEquals(
result,
expected,
"9001 retention (Jan 1, 2026 00:00:01) - should cutoff at Jan 1, 2025"
);
}
// Test 7: Special case 9001 - Mid year 2025
@@ -87,7 +106,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-06-15T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (mid 2025) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (mid 2025) - should cutoff at Jan 1, 2024"
);
}
// Test 8: Special case 9001 - Early 2024
@@ -96,14 +119,18 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2024-02-01T12:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2023-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (early 2024) - should cutoff at Jan 1, 2023");
assertEquals(
result,
expected,
"9001 retention (early 2024) - should cutoff at Jan 1, 2023"
);
}
// Test 9: 1 day retention
{
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(1, now);
const expected = now - (1 * 24 * 60 * 60);
const expected = now - 1 * 24 * 60 * 60;
assertEquals(result, expected, "1 day retention calculation failed");
}
@@ -111,7 +138,7 @@ function testCalculateCutoffTimestamp() {
{
const now = dateToTimestamp("2025-12-06T12:00:00Z");
const result = calculateCutoffTimestampWithNow(365, now);
const expected = now - (365 * 24 * 60 * 60);
const expected = now - 365 * 24 * 60 * 60;
assertEquals(result, expected, "365 days retention calculation failed");
}
@@ -125,9 +152,17 @@ function testCalculateCutoffTimestamp() {
const logFromJan2024 = dateToTimestamp("2024-01-01T00:00:00Z");
// Log from Dec 2023 should be before cutoff (deleted)
assertEquals(logFromDec2023 < cutoff, true, "Log from Dec 2023 should be deleted");
assertEquals(
logFromDec2023 < cutoff,
true,
"Log from Dec 2023 should be deleted"
);
// Log from Jan 2024 should be at or after cutoff (kept)
assertEquals(logFromJan2024 >= cutoff, true, "Log from Jan 2024 should be kept");
assertEquals(
logFromJan2024 >= cutoff,
true,
"Log from Jan 2024 should be kept"
);
}
// Test 12: Verify 9001 in 2026 - logs from 2024 should now be deleted
@@ -138,9 +173,17 @@ function testCalculateCutoffTimestamp() {
const logFromJan2025 = dateToTimestamp("2025-01-01T00:00:00Z");
// Log from Dec 2024 should be before cutoff (deleted)
assertEquals(logFromDec2024 < cutoff, true, "Log from Dec 2024 should be deleted in 2026");
assertEquals(
logFromDec2024 < cutoff,
true,
"Log from Dec 2024 should be deleted in 2026"
);
// Log from Jan 2025 should be at or after cutoff (kept)
assertEquals(logFromJan2025 >= cutoff, true, "Log from Jan 2025 should be kept in 2026");
assertEquals(
logFromJan2025 >= cutoff,
true,
"Log from Jan 2025 should be kept in 2026"
);
}
// Test 13: Edge case - exactly at year boundary for 9001
@@ -149,7 +192,11 @@ function testCalculateCutoffTimestamp() {
const now = dateToTimestamp("2025-01-01T00:00:00Z");
const result = calculateCutoffTimestampWithNow(9001, now);
const expected = dateToTimestamp("2024-01-01T00:00:00Z");
assertEquals(result, expected, "9001 retention (Jan 1, 2025 00:00:00) - should cutoff at Jan 1, 2024");
assertEquals(
result,
expected,
"9001 retention (Jan 1, 2025 00:00:00) - should cutoff at Jan 1, 2024"
);
}
// Test 14: Verify data from 2024 is kept throughout 2025 when using 9001
@@ -157,18 +204,29 @@ function testCalculateCutoffTimestamp() {
{
// Running in June 2025
const nowJune2025 = dateToTimestamp("2025-06-15T12:00:00Z");
const cutoffJune2025 = calculateCutoffTimestampWithNow(9001, nowJune2025);
const cutoffJune2025 = calculateCutoffTimestampWithNow(
9001,
nowJune2025
);
const logFromJuly2024 = dateToTimestamp("2024-07-15T12:00:00Z");
// Log from July 2024 should be KEPT in June 2025
assertEquals(logFromJuly2024 >= cutoffJune2025, true, "Log from July 2024 should be kept in June 2025");
assertEquals(
logFromJuly2024 >= cutoffJune2025,
true,
"Log from July 2024 should be kept in June 2025"
);
// Running in January 2026
const nowJan2026 = dateToTimestamp("2026-01-15T12:00:00Z");
const cutoffJan2026 = calculateCutoffTimestampWithNow(9001, nowJan2026);
// Log from July 2024 should be DELETED in January 2026
assertEquals(logFromJuly2024 < cutoffJan2026, true, "Log from July 2024 should be deleted in Jan 2026");
assertEquals(
logFromJuly2024 < cutoffJan2026,
true,
"Log from July 2024 should be deleted in Jan 2026"
);
}
// Test 15: Verify the exact requirement - data from 2024 must be purged on December 31, 2025
@@ -179,13 +237,24 @@ function testCalculateCutoffTimestamp() {
// Dec 31, 2025 23:59:59 - still 2025, log should be kept
const nowDec31_2025 = dateToTimestamp("2025-12-31T23:59:59Z");
const cutoffDec31 = calculateCutoffTimestampWithNow(9001, nowDec31_2025);
assertEquals(logFromMid2024 >= cutoffDec31, true, "Log from mid-2024 should be kept on Dec 31, 2025");
const cutoffDec31 = calculateCutoffTimestampWithNow(
9001,
nowDec31_2025
);
assertEquals(
logFromMid2024 >= cutoffDec31,
true,
"Log from mid-2024 should be kept on Dec 31, 2025"
);
// Jan 1, 2026 00:00:00 - now 2026, log can be deleted
const nowJan1_2026 = dateToTimestamp("2026-01-01T00:00:00Z");
const cutoffJan1 = calculateCutoffTimestampWithNow(9001, nowJan1_2026);
assertEquals(logFromMid2024 < cutoffJan1, true, "Log from mid-2024 should be deleted on Jan 1, 2026");
assertEquals(
logFromMid2024 < cutoffJan1,
true,
"Log from mid-2024 should be deleted on Jan 1, 2026"
);
}
console.log("All calculateCutoffTimestamp tests passed!");
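For reference, a sketch of the cutoff logic these tests pin down, reconstructed from the expectations above; the real implementation lives in cleanupLogs.ts and may differ in detail:

```ts
// Timestamps are epoch seconds, matching the tests above.
function calculateCutoffTimestampWithNow(
    retentionDays: number,
    nowTimestamp: number
): number {
    if (retentionDays === 9001) {
        // Logs from year Y survive through Dec 31 of year Y+1, so the
        // cutoff is Jan 1 (UTC) of the year before the current one.
        const year = new Date(nowTimestamp * 1000).getUTCFullYear();
        return Date.UTC(year - 1, 0, 1) / 1000;
    }
    return nowTimestamp - retentionDays * 24 * 60 * 60;
}
```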

View File

@@ -4,11 +4,13 @@ import { eq, and } from "drizzle-orm";
import { subdomainSchema } from "@server/lib/schemas";
import { fromError } from "zod-validation-error";
export type DomainValidationResult = {
export type DomainValidationResult =
| {
success: true;
fullDomain: string;
subdomain: string | null;
} | {
}
| {
success: false;
error: string;
};
@@ -34,7 +36,10 @@ export async function validateAndConstructDomain(
.where(eq(domains.domainId, domainId))
.leftJoin(
orgDomains,
and(eq(orgDomains.orgId, orgId), eq(orgDomains.domainId, domainId))
and(
eq(orgDomains.orgId, orgId),
eq(orgDomains.domainId, domainId)
)
);
// Check if domain exists
@@ -106,7 +111,7 @@ export async function validateAndConstructDomain(
} catch (error) {
return {
success: false,
error: `An error occurred while validating domain: ${error instanceof Error ? error.message : 'Unknown error'}`
error: `An error occurred while validating domain: ${error instanceof Error ? error.message : "Unknown error"}`
};
}
}

View File

@@ -1,37 +1,37 @@
import crypto from 'crypto';
import crypto from "crypto";
export function encryptData(data: string, key: Buffer): string {
const algorithm = 'aes-256-gcm';
const algorithm = "aes-256-gcm";
const iv = crypto.randomBytes(16);
const cipher = crypto.createCipheriv(algorithm, key, iv);
let encrypted = cipher.update(data, 'utf8', 'hex');
encrypted += cipher.final('hex');
let encrypted = cipher.update(data, "utf8", "hex");
encrypted += cipher.final("hex");
const authTag = cipher.getAuthTag();
// Combine IV, auth tag, and encrypted data
return iv.toString('hex') + ':' + authTag.toString('hex') + ':' + encrypted;
return iv.toString("hex") + ":" + authTag.toString("hex") + ":" + encrypted;
}
// Helper function to decrypt data (you'll need this to read certificates)
export function decryptData(encryptedData: string, key: Buffer): string {
const algorithm = 'aes-256-gcm';
const parts = encryptedData.split(':');
const algorithm = "aes-256-gcm";
const parts = encryptedData.split(":");
if (parts.length !== 3) {
throw new Error('Invalid encrypted data format');
throw new Error("Invalid encrypted data format");
}
const iv = Buffer.from(parts[0], 'hex');
const authTag = Buffer.from(parts[1], 'hex');
const iv = Buffer.from(parts[0], "hex");
const authTag = Buffer.from(parts[1], "hex");
const encrypted = parts[2];
const decipher = crypto.createDecipheriv(algorithm, key, iv);
decipher.setAuthTag(authTag);
let decrypted = decipher.update(encrypted, 'hex', 'utf8');
decrypted += decipher.final('utf8');
let decrypted = decipher.update(encrypted, "hex", "utf8");
decrypted += decipher.final("utf8");
return decrypted;
}
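A round-trip usage sketch for the two helpers above; the key is generated on the spot for illustration, where real callers would load a persistent 32-byte key:

```ts
import crypto from "crypto";

const key = crypto.randomBytes(32); // AES-256-GCM requires a 32-byte key

const secret = "certificate PEM contents";
const encrypted = encryptData(secret, key); // "iv:authTag:ciphertext", all hex
console.log(encrypted.split(":").length); // 3

const decrypted = decryptData(encrypted, key);
console.log(decrypted === secret); // true
```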

View File

@@ -33,7 +33,11 @@ export async function generateOidcRedirectUrl(
)
.limit(1);
if (res?.loginPage && res.loginPage.domainId && res.loginPage.fullDomain) {
if (
res?.loginPage &&
res.loginPage.domainId &&
res.loginPage.fullDomain
) {
baseUrl = `${method}://${res.loginPage.fullDomain}`;
}
}

View File

@@ -23,7 +23,11 @@ function testFindNextAvailableCidr() {
{
const existing = ["10.0.0.0/16", "10.2.0.0/16"];
const result = findNextAvailableCidr(existing, 16, "10.0.0.0/8");
assertEquals(result, "10.1.0.0/16", "Finding gap between allocations failed");
assertEquals(
result,
"10.1.0.0/16",
"Finding gap between allocations failed"
);
}
// Test 3: No available space

View File

@@ -247,7 +247,10 @@ export async function getNextAvailableClientSubnet(
orgId: string,
transaction: Transaction | typeof db = db
): Promise<string> {
const [org] = await transaction.select().from(orgs).where(eq(orgs.orgId, orgId));
const [org] = await transaction
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
if (!org) {
throw new Error(`Organization with ID ${orgId} not found`);
@@ -360,7 +363,9 @@ export async function getNextAvailableOrgSubnet(): Promise<string> {
return subnet;
}
export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[] {
export function generateRemoteSubnets(
allSiteResources: SiteResource[]
): string[] {
const remoteSubnets = allSiteResources
.filter((sr) => {
if (sr.mode === "cidr") return true;

View File

@@ -14,7 +14,8 @@ export const configSchema = z
.object({
app: z
.object({
dashboard_url: z.url()
dashboard_url: z
.url()
.pipe(z.url())
.transform((url) => url.toLowerCase())
.optional(),
@@ -255,7 +256,10 @@ export const configSchema = z
.object({
block_size: z.number().positive().gt(0).optional().default(24),
subnet_group: z.string().optional().default("100.90.128.0/24"),
utility_subnet_group: z.string().optional().default("100.96.128.0/24") //just hardcode this for now as well
utility_subnet_group: z
.string()
.optional()
.default("100.96.128.0/24") //just hardcode this for now as well
})
.optional()
.default({

View File

@@ -32,7 +32,7 @@ import logger from "@server/logger";
import {
generateAliasConfig,
generateRemoteSubnets,
generateSubnetProxyTargets,
generateSubnetProxyTargets
} from "@server/lib/ip";
import {
addPeerData,
@@ -109,7 +109,8 @@ export async function getClientSiteResourceAccess(
const directClientIds = allClientSiteResources.map((row) => row.clientId);
// Get full client details for directly associated clients
const directClients = directClientIds.length > 0
const directClients =
directClientIds.length > 0
? await trx
.select({
clientId: clients.clientId,
@@ -731,7 +732,8 @@ async function handleSubnetProxyTargetUpdates(
);
// Only remove remote subnet if no other resource uses the same destination
const remoteSubnetsToRemove = destinationStillInUse.length > 0
const remoteSubnetsToRemove =
destinationStillInUse.length > 0
? []
: generateRemoteSubnets([siteResource]);
@@ -817,7 +819,10 @@ export async function rebuildClientAssociationsFromClient(
.from(roleSiteResources)
.innerJoin(
siteResources,
eq(siteResources.siteResourceId, roleSiteResources.siteResourceId)
eq(
siteResources.siteResourceId,
roleSiteResources.siteResourceId
)
)
.where(
and(
@@ -1277,7 +1282,8 @@ async function handleMessagesForClientResources(
);
// Only remove remote subnet if no other resource uses the same destination
const remoteSubnetsToRemove = destinationStillInUse.length > 0
const remoteSubnetsToRemove =
destinationStillInUse.length > 0
? []
: generateRemoteSubnets([resource]);

View File

@@ -3,14 +3,14 @@ import { Response } from "express";
export const response = <T>(
res: Response,
{ data, success, error, message, status }: ResponseT<T>,
{ data, success, error, message, status }: ResponseT<T>
) => {
return res.status(status).send({
data,
success,
error,
message,
status,
status
});
};

View File

@@ -1,5 +1,5 @@
import { S3Client } from "@aws-sdk/client-s3";
export const s3Client = new S3Client({
region: process.env.S3_REGION || "us-east-1",
region: process.env.S3_REGION || "us-east-1"
});

View File

@@ -6,7 +6,7 @@ let serverIp: string | null = null;
const services = [
"https://checkip.amazonaws.com",
"https://ifconfig.io/ip",
"https://api.ipify.org",
"https://api.ipify.org"
];
export async function fetchServerIp() {
@@ -17,7 +17,9 @@ export async function fetchServerIp() {
logger.debug("Detected public IP: " + serverIp);
return;
} catch (err: any) {
console.warn(`Failed to fetch server IP from ${url}: ${err.message || err.code}`);
console.warn(
`Failed to fetch server IP from ${url}: ${err.message || err.code}`
);
}
}

View File

@@ -1,8 +1,7 @@
export default function stoi(val: any) {
if (typeof val === "string") {
return parseInt(val);
}
else {
} else {
return val;
}
}

View File

@@ -195,7 +195,9 @@ export class TraefikConfigManager {
state.set(domain, {
exists: certExists && keyExists,
lastModified: lastModified ? Math.floor(lastModified.getTime() / 1000) : null,
lastModified: lastModified
? Math.floor(lastModified.getTime() / 1000)
: null,
expiresAt,
wildcard
});
@@ -464,7 +466,9 @@ export class TraefikConfigManager {
config.getRawConfig().traefik.site_types,
build == "oss", // filter out the namespace domains in open source
build != "oss", // generate the login pages on the cloud and hybrid,
build == "saas" ? false : config.getRawConfig().traefik.allow_raw_resources // dont allow raw resources on saas otherwise use config
build == "saas"
? false
: config.getRawConfig().traefik.allow_raw_resources // dont allow raw resources on saas otherwise use config
);
const domains = new Set<string>();
@@ -788,7 +792,10 @@ export class TraefikConfigManager {
// Store the certificate expiry time
if (cert.expiresAt) {
const expiresAtPath = path.join(domainDir, ".expires_at");
const expiresAtPath = path.join(
domainDir,
".expires_at"
);
fs.writeFileSync(
expiresAtPath,
cert.expiresAt.toString(),

View File

@@ -2,234 +2,249 @@ import { assertEquals } from "@test/assert";
import { isDomainCoveredByWildcard } from "./TraefikConfigManager";
function runTests() {
console.log('Running wildcard domain coverage tests...');
console.log("Running wildcard domain coverage tests...");
// Test case 1: Basic wildcard certificate at example.com
const basicWildcardCerts = new Map([
['example.com', { exists: true, wildcard: true }]
["example.com", { exists: true, wildcard: true }]
]);
// Should match first-level subdomains
assertEquals(
isDomainCoveredByWildcard('level1.example.com', basicWildcardCerts),
isDomainCoveredByWildcard("level1.example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match level1.example.com'
"Wildcard cert at example.com should match level1.example.com"
);
assertEquals(
isDomainCoveredByWildcard('api.example.com', basicWildcardCerts),
isDomainCoveredByWildcard("api.example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match api.example.com'
"Wildcard cert at example.com should match api.example.com"
);
assertEquals(
isDomainCoveredByWildcard('www.example.com', basicWildcardCerts),
isDomainCoveredByWildcard("www.example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match www.example.com'
"Wildcard cert at example.com should match www.example.com"
);
// Should match the root domain (exact match)
assertEquals(
isDomainCoveredByWildcard('example.com', basicWildcardCerts),
isDomainCoveredByWildcard("example.com", basicWildcardCerts),
true,
'Wildcard cert at example.com should match example.com itself'
"Wildcard cert at example.com should match example.com itself"
);
// Should NOT match second-level subdomains
assertEquals(
isDomainCoveredByWildcard('level2.level1.example.com', basicWildcardCerts),
isDomainCoveredByWildcard(
"level2.level1.example.com",
basicWildcardCerts
),
false,
'Wildcard cert at example.com should NOT match level2.level1.example.com'
"Wildcard cert at example.com should NOT match level2.level1.example.com"
);
assertEquals(
isDomainCoveredByWildcard('deep.nested.subdomain.example.com', basicWildcardCerts),
isDomainCoveredByWildcard(
"deep.nested.subdomain.example.com",
basicWildcardCerts
),
false,
'Wildcard cert at example.com should NOT match deep.nested.subdomain.example.com'
"Wildcard cert at example.com should NOT match deep.nested.subdomain.example.com"
);
// Should NOT match different domains
assertEquals(
isDomainCoveredByWildcard('test.otherdomain.com', basicWildcardCerts),
isDomainCoveredByWildcard("test.otherdomain.com", basicWildcardCerts),
false,
'Wildcard cert at example.com should NOT match test.otherdomain.com'
"Wildcard cert at example.com should NOT match test.otherdomain.com"
);
assertEquals(
isDomainCoveredByWildcard('notexample.com', basicWildcardCerts),
isDomainCoveredByWildcard("notexample.com", basicWildcardCerts),
false,
'Wildcard cert at example.com should NOT match notexample.com'
"Wildcard cert at example.com should NOT match notexample.com"
);
// Test case 2: Multiple wildcard certificates
const multipleWildcardCerts = new Map([
['example.com', { exists: true, wildcard: true }],
['test.org', { exists: true, wildcard: true }],
['api.service.net', { exists: true, wildcard: true }]
["example.com", { exists: true, wildcard: true }],
["test.org", { exists: true, wildcard: true }],
["api.service.net", { exists: true, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('app.example.com', multipleWildcardCerts),
isDomainCoveredByWildcard("app.example.com", multipleWildcardCerts),
true,
'Should match subdomain of first wildcard cert'
"Should match subdomain of first wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('staging.test.org', multipleWildcardCerts),
isDomainCoveredByWildcard("staging.test.org", multipleWildcardCerts),
true,
'Should match subdomain of second wildcard cert'
"Should match subdomain of second wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('v1.api.service.net', multipleWildcardCerts),
isDomainCoveredByWildcard("v1.api.service.net", multipleWildcardCerts),
true,
'Should match subdomain of third wildcard cert'
"Should match subdomain of third wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('deep.nested.api.service.net', multipleWildcardCerts),
isDomainCoveredByWildcard(
"deep.nested.api.service.net",
multipleWildcardCerts
),
false,
'Should NOT match multi-level subdomain of third wildcard cert'
"Should NOT match multi-level subdomain of third wildcard cert"
);
// Test exact domain matches for multiple certs
assertEquals(
isDomainCoveredByWildcard('example.com', multipleWildcardCerts),
isDomainCoveredByWildcard("example.com", multipleWildcardCerts),
true,
'Should match exact domain of first wildcard cert'
"Should match exact domain of first wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('test.org', multipleWildcardCerts),
isDomainCoveredByWildcard("test.org", multipleWildcardCerts),
true,
'Should match exact domain of second wildcard cert'
"Should match exact domain of second wildcard cert"
);
assertEquals(
isDomainCoveredByWildcard('api.service.net', multipleWildcardCerts),
isDomainCoveredByWildcard("api.service.net", multipleWildcardCerts),
true,
'Should match exact domain of third wildcard cert'
"Should match exact domain of third wildcard cert"
);
// Test case 3: Non-wildcard certificates (should not match anything)
const nonWildcardCerts = new Map([
['example.com', { exists: true, wildcard: false }],
['specific.domain.com', { exists: true, wildcard: false }]
["example.com", { exists: true, wildcard: false }],
["specific.domain.com", { exists: true, wildcard: false }]
]);
assertEquals(
isDomainCoveredByWildcard('sub.example.com', nonWildcardCerts),
isDomainCoveredByWildcard("sub.example.com", nonWildcardCerts),
false,
'Non-wildcard cert should not match subdomains'
"Non-wildcard cert should not match subdomains"
);
assertEquals(
isDomainCoveredByWildcard('example.com', nonWildcardCerts),
isDomainCoveredByWildcard("example.com", nonWildcardCerts),
false,
'Non-wildcard cert should not match even exact domain via this function'
"Non-wildcard cert should not match even exact domain via this function"
);
// Test case 4: Non-existent certificates (should not match)
const nonExistentCerts = new Map([
['example.com', { exists: false, wildcard: true }],
['missing.com', { exists: false, wildcard: true }]
["example.com", { exists: false, wildcard: true }],
["missing.com", { exists: false, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('sub.example.com', nonExistentCerts),
isDomainCoveredByWildcard("sub.example.com", nonExistentCerts),
false,
'Non-existent wildcard cert should not match'
"Non-existent wildcard cert should not match"
);
// Test case 5: Edge cases with special domain names
const specialDomainCerts = new Map([
['localhost', { exists: true, wildcard: true }],
['127-0-0-1.nip.io', { exists: true, wildcard: true }],
['xn--e1afmkfd.xn--p1ai', { exists: true, wildcard: true }] // IDN domain
["localhost", { exists: true, wildcard: true }],
["127-0-0-1.nip.io", { exists: true, wildcard: true }],
["xn--e1afmkfd.xn--p1ai", { exists: true, wildcard: true }] // IDN domain
]);
assertEquals(
isDomainCoveredByWildcard('app.localhost', specialDomainCerts),
isDomainCoveredByWildcard("app.localhost", specialDomainCerts),
true,
'Should match subdomain of localhost wildcard'
"Should match subdomain of localhost wildcard"
);
assertEquals(
isDomainCoveredByWildcard('test.127-0-0-1.nip.io', specialDomainCerts),
isDomainCoveredByWildcard("test.127-0-0-1.nip.io", specialDomainCerts),
true,
'Should match subdomain of nip.io wildcard'
"Should match subdomain of nip.io wildcard"
);
assertEquals(
isDomainCoveredByWildcard('sub.xn--e1afmkfd.xn--p1ai', specialDomainCerts),
isDomainCoveredByWildcard(
"sub.xn--e1afmkfd.xn--p1ai",
specialDomainCerts
),
true,
'Should match subdomain of IDN wildcard'
"Should match subdomain of IDN wildcard"
);
// Test case 6: Empty input and edge cases
const emptyCerts = new Map();
assertEquals(
isDomainCoveredByWildcard('any.domain.com', emptyCerts),
isDomainCoveredByWildcard("any.domain.com", emptyCerts),
false,
'Empty certificate map should not match any domain'
"Empty certificate map should not match any domain"
);
// Test case 7: Domains with single character components
const singleCharCerts = new Map([
['a.com', { exists: true, wildcard: true }],
['x.y.z', { exists: true, wildcard: true }]
["a.com", { exists: true, wildcard: true }],
["x.y.z", { exists: true, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('b.a.com', singleCharCerts),
isDomainCoveredByWildcard("b.a.com", singleCharCerts),
true,
'Should match single character subdomain'
"Should match single character subdomain"
);
assertEquals(
isDomainCoveredByWildcard('w.x.y.z', singleCharCerts),
isDomainCoveredByWildcard("w.x.y.z", singleCharCerts),
true,
'Should match single character subdomain of multi-part domain'
"Should match single character subdomain of multi-part domain"
);
assertEquals(
isDomainCoveredByWildcard('v.w.x.y.z', singleCharCerts),
isDomainCoveredByWildcard("v.w.x.y.z", singleCharCerts),
false,
'Should NOT match multi-level subdomain of single char domain'
"Should NOT match multi-level subdomain of single char domain"
);
// Test case 8: Domains with numbers and hyphens
const numericCerts = new Map([
['api-v2.service-1.com', { exists: true, wildcard: true }],
['123.456.net', { exists: true, wildcard: true }]
["api-v2.service-1.com", { exists: true, wildcard: true }],
["123.456.net", { exists: true, wildcard: true }]
]);
assertEquals(
isDomainCoveredByWildcard('staging.api-v2.service-1.com', numericCerts),
isDomainCoveredByWildcard("staging.api-v2.service-1.com", numericCerts),
true,
'Should match subdomain with hyphens and numbers'
"Should match subdomain with hyphens and numbers"
);
assertEquals(
isDomainCoveredByWildcard('test.123.456.net', numericCerts),
isDomainCoveredByWildcard("test.123.456.net", numericCerts),
true,
'Should match subdomain with numeric components'
"Should match subdomain with numeric components"
);
assertEquals(
isDomainCoveredByWildcard('deep.staging.api-v2.service-1.com', numericCerts),
isDomainCoveredByWildcard(
"deep.staging.api-v2.service-1.com",
numericCerts
),
false,
'Should NOT match multi-level subdomain with hyphens and numbers'
"Should NOT match multi-level subdomain with hyphens and numbers"
);
console.log('All wildcard domain coverage tests passed!');
console.log("All wildcard domain coverage tests passed!");
}
// Run all tests
try {
runTests();
} catch (error) {
console.error('Test failed:', error);
console.error("Test failed:", error);
process.exit(1);
}
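
Note: read together, the assertions above pin down the matching rule — a wildcard certificate covers its own apex domain and exactly one subdomain level, and only when the entry both exists and is flagged as a wildcard. A minimal sketch consistent with those tests (hypothetical helper name; the shipped isDomainCoveredByWildcard may differ in detail):

type CertInfo = { exists: boolean; wildcard: boolean };

function isDomainCoveredByWildcardSketch(
    domain: string,
    certs: Map<string, CertInfo>
): boolean {
    // A wildcard cert also covers its exact apex domain.
    const direct = certs.get(domain);
    if (direct?.exists && direct.wildcard) return true;
    // Strip exactly one label: wildcards cover a single subdomain level,
    // so deep.sub.example.com is NOT covered by a cert on example.com.
    const dot = domain.indexOf(".");
    if (dot === -1) return false;
    const parent = certs.get(domain.slice(dot + 1));
    return Boolean(parent?.exists && parent.wildcard);
}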

View File

@@ -31,12 +31,17 @@ export function validatePathRewriteConfig(
}
if (rewritePathType !== "stripPrefix") {
if ((rewritePath && !rewritePathType) || (!rewritePath && rewritePathType)) {
return { isValid: false, error: "Both rewritePath and rewritePathType must be specified together" };
if (
(rewritePath && !rewritePathType) ||
(!rewritePath && rewritePathType)
) {
return {
isValid: false,
error: "Both rewritePath and rewritePathType must be specified together"
};
}
}
if (!rewritePath || !rewritePathType) {
return { isValid: true };
}
@@ -68,14 +73,14 @@ export function validatePathRewriteConfig(
}
}
// Additional validation for stripPrefix
if (rewritePathType === "stripPrefix") {
if (pathMatchType !== "prefix") {
logger.warn(`stripPrefix rewrite type is most effective with prefix path matching. Current match type: ${pathMatchType}`);
logger.warn(
`stripPrefix rewrite type is most effective with prefix path matching. Current match type: ${pathMatchType}`
);
}
}
return { isValid: true };
}
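
Note: the guard above enforces a both-or-neither contract between rewritePath and rewritePathType (outside of stripPrefix). The same condition reduced to its XOR core (hypothetical standalone form, not the function above):

// Exactly one of the two fields being set is the invalid state.
function rewriteFieldsConsistent(
    rewritePath?: string,
    rewritePathType?: string
): boolean {
    return Boolean(rewritePath) === Boolean(rewritePathType);
}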

View File

@@ -2,70 +2,246 @@ import { isValidUrlGlobPattern } from "./validators";
import { assertEquals } from "@test/assert";
function runTests() {
console.log('Running URL pattern validation tests...');
console.log("Running URL pattern validation tests...");
// Test valid patterns
assertEquals(isValidUrlGlobPattern('simple'), true, 'Simple path segment should be valid');
assertEquals(isValidUrlGlobPattern('simple/path'), true, 'Simple path with slash should be valid');
assertEquals(isValidUrlGlobPattern('/leading/slash'), true, 'Path with leading slash should be valid');
assertEquals(isValidUrlGlobPattern('path/'), true, 'Path with trailing slash should be valid');
assertEquals(isValidUrlGlobPattern('path/*'), true, 'Path with wildcard segment should be valid');
assertEquals(isValidUrlGlobPattern('*'), true, 'Single wildcard should be valid');
assertEquals(isValidUrlGlobPattern('*/subpath'), true, 'Wildcard with subpath should be valid');
assertEquals(isValidUrlGlobPattern('path/*/more'), true, 'Path with wildcard in the middle should be valid');
assertEquals(
isValidUrlGlobPattern("simple"),
true,
"Simple path segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("simple/path"),
true,
"Simple path with slash should be valid"
);
assertEquals(
isValidUrlGlobPattern("/leading/slash"),
true,
"Path with leading slash should be valid"
);
assertEquals(
isValidUrlGlobPattern("path/"),
true,
"Path with trailing slash should be valid"
);
assertEquals(
isValidUrlGlobPattern("path/*"),
true,
"Path with wildcard segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("*"),
true,
"Single wildcard should be valid"
);
assertEquals(
isValidUrlGlobPattern("*/subpath"),
true,
"Wildcard with subpath should be valid"
);
assertEquals(
isValidUrlGlobPattern("path/*/more"),
true,
"Path with wildcard in the middle should be valid"
);
// Test with special characters
assertEquals(isValidUrlGlobPattern('path-with-dash'), true, 'Path with dash should be valid');
assertEquals(isValidUrlGlobPattern('path_with_underscore'), true, 'Path with underscore should be valid');
assertEquals(isValidUrlGlobPattern('path.with.dots'), true, 'Path with dots should be valid');
assertEquals(isValidUrlGlobPattern('path~with~tilde'), true, 'Path with tilde should be valid');
assertEquals(isValidUrlGlobPattern('path!with!exclamation'), true, 'Path with exclamation should be valid');
assertEquals(isValidUrlGlobPattern('path$with$dollar'), true, 'Path with dollar should be valid');
assertEquals(isValidUrlGlobPattern('path&with&ampersand'), true, 'Path with ampersand should be valid');
assertEquals(isValidUrlGlobPattern("path'with'quote"), true, "Path with quote should be valid");
assertEquals(isValidUrlGlobPattern('path(with)parentheses'), true, 'Path with parentheses should be valid');
assertEquals(isValidUrlGlobPattern('path+with+plus'), true, 'Path with plus should be valid');
assertEquals(isValidUrlGlobPattern('path,with,comma'), true, 'Path with comma should be valid');
assertEquals(isValidUrlGlobPattern('path;with;semicolon'), true, 'Path with semicolon should be valid');
assertEquals(isValidUrlGlobPattern('path=with=equals'), true, 'Path with equals should be valid');
assertEquals(isValidUrlGlobPattern('path:with:colon'), true, 'Path with colon should be valid');
assertEquals(isValidUrlGlobPattern('path@with@at'), true, 'Path with at should be valid');
assertEquals(
isValidUrlGlobPattern("path-with-dash"),
true,
"Path with dash should be valid"
);
assertEquals(
isValidUrlGlobPattern("path_with_underscore"),
true,
"Path with underscore should be valid"
);
assertEquals(
isValidUrlGlobPattern("path.with.dots"),
true,
"Path with dots should be valid"
);
assertEquals(
isValidUrlGlobPattern("path~with~tilde"),
true,
"Path with tilde should be valid"
);
assertEquals(
isValidUrlGlobPattern("path!with!exclamation"),
true,
"Path with exclamation should be valid"
);
assertEquals(
isValidUrlGlobPattern("path$with$dollar"),
true,
"Path with dollar should be valid"
);
assertEquals(
isValidUrlGlobPattern("path&with&ampersand"),
true,
"Path with ampersand should be valid"
);
assertEquals(
isValidUrlGlobPattern("path'with'quote"),
true,
"Path with quote should be valid"
);
assertEquals(
isValidUrlGlobPattern("path(with)parentheses"),
true,
"Path with parentheses should be valid"
);
assertEquals(
isValidUrlGlobPattern("path+with+plus"),
true,
"Path with plus should be valid"
);
assertEquals(
isValidUrlGlobPattern("path,with,comma"),
true,
"Path with comma should be valid"
);
assertEquals(
isValidUrlGlobPattern("path;with;semicolon"),
true,
"Path with semicolon should be valid"
);
assertEquals(
isValidUrlGlobPattern("path=with=equals"),
true,
"Path with equals should be valid"
);
assertEquals(
isValidUrlGlobPattern("path:with:colon"),
true,
"Path with colon should be valid"
);
assertEquals(
isValidUrlGlobPattern("path@with@at"),
true,
"Path with at should be valid"
);
// Test with percent encoding
assertEquals(isValidUrlGlobPattern('path%20with%20spaces'), true, 'Path with percent-encoded spaces should be valid');
assertEquals(isValidUrlGlobPattern('path%2Fwith%2Fencoded%2Fslashes'), true, 'Path with percent-encoded slashes should be valid');
assertEquals(
isValidUrlGlobPattern("path%20with%20spaces"),
true,
"Path with percent-encoded spaces should be valid"
);
assertEquals(
isValidUrlGlobPattern("path%2Fwith%2Fencoded%2Fslashes"),
true,
"Path with percent-encoded slashes should be valid"
);
// Test with wildcards in segments (the fixed functionality)
assertEquals(isValidUrlGlobPattern('padbootstrap*'), true, 'Path with wildcard at the end of segment should be valid');
assertEquals(isValidUrlGlobPattern('pad*bootstrap'), true, 'Path with wildcard in the middle of segment should be valid');
assertEquals(isValidUrlGlobPattern('*bootstrap'), true, 'Path with wildcard at the start of segment should be valid');
assertEquals(isValidUrlGlobPattern('multiple*wildcards*in*segment'), true, 'Path with multiple wildcards in segment should be valid');
assertEquals(isValidUrlGlobPattern('wild*/cards/in*/different/seg*ments'), true, 'Path with wildcards in different segments should be valid');
assertEquals(
isValidUrlGlobPattern("padbootstrap*"),
true,
"Path with wildcard at the end of segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("pad*bootstrap"),
true,
"Path with wildcard in the middle of segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("*bootstrap"),
true,
"Path with wildcard at the start of segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("multiple*wildcards*in*segment"),
true,
"Path with multiple wildcards in segment should be valid"
);
assertEquals(
isValidUrlGlobPattern("wild*/cards/in*/different/seg*ments"),
true,
"Path with wildcards in different segments should be valid"
);
// Test invalid patterns
assertEquals(isValidUrlGlobPattern(''), false, 'Empty string should be invalid');
assertEquals(isValidUrlGlobPattern('//double/slash'), false, 'Path with double slash should be invalid');
assertEquals(isValidUrlGlobPattern('path//end'), false, 'Path with double slash in the middle should be invalid');
assertEquals(isValidUrlGlobPattern('invalid<char>'), false, 'Path with invalid characters should be invalid');
assertEquals(isValidUrlGlobPattern('invalid|char'), false, 'Path with invalid pipe character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid"char'), false, 'Path with invalid quote character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid`char'), false, 'Path with invalid backtick character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid^char'), false, 'Path with invalid caret character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid\\char'), false, 'Path with invalid backslash character should be invalid');
assertEquals(isValidUrlGlobPattern('invalid[char]'), false, 'Path with invalid square brackets should be invalid');
assertEquals(isValidUrlGlobPattern('invalid{char}'), false, 'Path with invalid curly braces should be invalid');
assertEquals(
isValidUrlGlobPattern(""),
false,
"Empty string should be invalid"
);
assertEquals(
isValidUrlGlobPattern("//double/slash"),
false,
"Path with double slash should be invalid"
);
assertEquals(
isValidUrlGlobPattern("path//end"),
false,
"Path with double slash in the middle should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid<char>"),
false,
"Path with invalid characters should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid|char"),
false,
"Path with invalid pipe character should be invalid"
);
assertEquals(
isValidUrlGlobPattern('invalid"char'),
false,
"Path with invalid quote character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid`char"),
false,
"Path with invalid backtick character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid^char"),
false,
"Path with invalid caret character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid\\char"),
false,
"Path with invalid backslash character should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid[char]"),
false,
"Path with invalid square brackets should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid{char}"),
false,
"Path with invalid curly braces should be invalid"
);
// Test invalid percent encoding
assertEquals(isValidUrlGlobPattern('invalid%2'), false, 'Path with incomplete percent encoding should be invalid');
assertEquals(isValidUrlGlobPattern('invalid%GZ'), false, 'Path with invalid hex in percent encoding should be invalid');
assertEquals(isValidUrlGlobPattern('invalid%'), false, 'Path with isolated percent sign should be invalid');
assertEquals(
isValidUrlGlobPattern("invalid%2"),
false,
"Path with incomplete percent encoding should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid%GZ"),
false,
"Path with invalid hex in percent encoding should be invalid"
);
assertEquals(
isValidUrlGlobPattern("invalid%"),
false,
"Path with isolated percent sign should be invalid"
);
console.log('All tests passed!');
console.log("All tests passed!");
}
// Run all tests
try {
runTests();
} catch (error) {
console.error('Test failed:', error);
console.error("Test failed:", error);
}
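
Note: the assertions above collectively specify the validator — non-empty, no empty segments except a trailing slash, segments built from URL-safe characters with well-formed percent escapes, and * allowed anywhere within a segment. A compact re-implementation inferred from those assertions (hypothetical name; the real isValidUrlGlobPattern may be structured differently):

// Unreserved + sub-delims + ":" and "@", a literal "*", or a %XX escape.
const SEGMENT = /^(?:[A-Za-z0-9\-._~!$&'()*+,;=:@]|%[0-9A-Fa-f]{2})*$/;

function isValidUrlGlobPatternSketch(pattern: string): boolean {
    if (pattern === "") return false;
    const body = pattern.startsWith("/") ? pattern.slice(1) : pattern;
    const segments = body.split("/");
    return segments.every((seg, i) => {
        // Only the final segment may be empty (trailing slash);
        // an empty segment anywhere else means "//".
        if (seg === "") return i === segments.length - 1;
        return SEGMENT.test(seg);
    });
}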

View File

@@ -2,7 +2,9 @@ import z from "zod";
import ipaddr from "ipaddr.js";
export function isValidCIDR(cidr: string): boolean {
return z.cidrv4().safeParse(cidr).success || z.cidrv6().safeParse(cidr).success;
return (
z.cidrv4().safeParse(cidr).success || z.cidrv6().safeParse(cidr).success
);
}
export function isValidIP(ip: string): boolean {
@@ -168,14 +170,14 @@ export function validateHeaders(headers: string): boolean {
}
export function isSecondLevelDomain(domain: string): boolean {
if (!domain || typeof domain !== 'string') {
if (!domain || typeof domain !== "string") {
return false;
}
const trimmedDomain = domain.trim().toLowerCase();
// Split into parts
const parts = trimmedDomain.split('.');
const parts = trimmedDomain.split(".");
// Should have exactly 2 parts for a second-level domain (e.g., "example.com")
if (parts.length !== 2) {

View File

@@ -20,6 +20,6 @@ export const errorHandlerMiddleware: ErrorRequestHandler = (
error: true,
message: error.message || "Internal Server Error",
status: statusCode,
stack: process.env.ENVIRONMENT === "prod" ? null : error.stack,
stack: process.env.ENVIRONMENT === "prod" ? null : error.stack
});
};

View File

@@ -8,13 +8,13 @@ import HttpCode from "@server/types/HttpCode";
export async function getUserOrgs(
req: Request,
res: Response,
next: NextFunction,
next: NextFunction
) {
const userId = req.user?.userId; // Assuming you have user information in the request
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated"),
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
@@ -22,7 +22,7 @@ export async function getUserOrgs(
const userOrganizations = await db
.select({
orgId: userOrgs.orgId,
roleId: userOrgs.roleId,
roleId: userOrgs.roleId
})
.from(userOrgs)
.where(eq(userOrgs.userId, userId));
@@ -38,8 +38,8 @@ export async function getUserOrgs(
next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error retrieving user organizations",
),
"Error retrieving user organizations"
)
);
}
}

View File

@@ -97,7 +97,6 @@ export async function verifyApiKeyAccessTokenAccess(
);
}
return next();
} catch (e) {
return next(

View File

@@ -44,7 +44,10 @@ export async function verifyApiKeyApiKeyAccess(
.select()
.from(apiKeyOrg)
.where(
and(eq(apiKeys.apiKeyId, callerApiKey.apiKeyId), eq(apiKeyOrg.orgId, orgId))
and(
eq(apiKeys.apiKeyId, callerApiKey.apiKeyId),
eq(apiKeyOrg.orgId, orgId)
)
)
.limit(1);

View File

@@ -11,9 +11,12 @@ export async function verifyApiKeySetResourceClients(
next: NextFunction
) {
const apiKey = req.apiKey;
const singleClientId = req.params.clientId || req.body.clientId || req.query.clientId;
const singleClientId =
req.params.clientId || req.body.clientId || req.query.clientId;
const { clientIds } = req.body;
const allClientIds = clientIds || (singleClientId ? [parseInt(singleClientId as string)] : []);
const allClientIds =
clientIds ||
(singleClientId ? [parseInt(singleClientId as string)] : []);
if (!apiKey) {
return next(
@@ -70,4 +73,3 @@ export async function verifyApiKeySetResourceClients(
);
}
}

View File

@@ -11,7 +11,8 @@ export async function verifyApiKeySetResourceUsers(
next: NextFunction
) {
const apiKey = req.apiKey;
const singleUserId = req.params.userId || req.body.userId || req.query.userId;
const singleUserId =
req.params.userId || req.body.userId || req.query.userId;
const { userIds } = req.body;
const allUserIds = userIds || (singleUserId ? [singleUserId] : []);

View File

@@ -38,17 +38,12 @@ export async function verifyApiKeySiteResourceAccess(
const [siteResource] = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId)
))
.where(and(eq(siteResources.siteResourceId, siteResourceId)))
.limit(1);
if (!siteResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site resource not found"
)
createHttpError(HttpCode.NOT_FOUND, "Site resource not found")
);
}

View File

@@ -5,7 +5,7 @@ import HttpCode from "@server/types/HttpCode";
export function notFoundMiddleware(
req: Request,
res: Response,
next: NextFunction,
next: NextFunction
) {
if (req.path.startsWith("/api")) {
const message = `The requested url was not found - ${req.originalUrl}`;

View File

@@ -1,30 +1,32 @@
import { Request, Response, NextFunction } from 'express';
import logger from '@server/logger';
import createHttpError from 'http-errors';
import HttpCode from '@server/types/HttpCode';
import { Request, Response, NextFunction } from "express";
import logger from "@server/logger";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export function requestTimeoutMiddleware(timeoutMs: number = 30000) {
return (req: Request, res: Response, next: NextFunction) => {
// Set a timeout for the request
const timeout = setTimeout(() => {
if (!res.headersSent) {
logger.error(`Request timeout: ${req.method} ${req.url} from ${req.ip}`);
logger.error(
`Request timeout: ${req.method} ${req.url} from ${req.ip}`
);
return next(
createHttpError(
HttpCode.REQUEST_TIMEOUT,
'Request timeout - operation took too long to complete'
"Request timeout - operation took too long to complete"
)
);
}
}, timeoutMs);
// Clear timeout when response finishes
res.on('finish', () => {
res.on("finish", () => {
clearTimeout(timeout);
});
// Clear timeout when response closes
res.on('close', () => {
res.on("close", () => {
clearTimeout(timeout);
});

View File

@@ -76,7 +76,10 @@ export async function verifySiteAccess(
.select()
.from(userOrgs)
.where(
and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, site.orgId))
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, site.orgId)
)
)
.limit(1);
req.userOrg = userOrgRole[0];

View File

@@ -9,7 +9,10 @@ const nextPort = config.getRawConfig().server.next_port;
export async function createNextServer() {
// const app = next({ dev });
const app = next({ dev: process.env.ENVIRONMENT !== "prod", turbopack: true });
const app = next({
dev: process.env.ENVIRONMENT !== "prod",
turbopack: true
});
const handle = app.getRequestHandler();
await app.prepare();

View File

@@ -11,11 +11,14 @@
* This file is not licensed under the AGPLv3.
*/
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { RemoteExitNode, remoteExitNodes, remoteExitNodeSessions, RemoteExitNodeSession } from "@server/db";
import {
RemoteExitNode,
remoteExitNodes,
remoteExitNodeSessions,
RemoteExitNodeSession
} from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
@@ -23,30 +26,39 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createRemoteExitNodeSession(
token: string,
remoteExitNodeId: string,
remoteExitNodeId: string
): Promise<RemoteExitNodeSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: RemoteExitNodeSession = {
sessionId: sessionId,
remoteExitNodeId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(remoteExitNodeSessions).values(session);
return session;
}
export async function validateRemoteExitNodeSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ remoteExitNode: remoteExitNodes, session: remoteExitNodeSessions })
.select({
remoteExitNode: remoteExitNodes,
session: remoteExitNodeSessions
})
.from(remoteExitNodeSessions)
.innerJoin(remoteExitNodes, eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodes.remoteExitNodeId))
.innerJoin(
remoteExitNodes,
eq(
remoteExitNodeSessions.remoteExitNodeId,
remoteExitNodes.remoteExitNodeId
)
)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, remoteExitNode: null };
@@ -58,26 +70,32 @@ export async function validateRemoteExitNodeSessionToken(
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
return { session: null, remoteExitNode: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(remoteExitNodeSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
}
return { session, remoteExitNode };
}
export async function invalidateRemoteExitNodeSession(sessionId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.sessionId, sessionId));
export async function invalidateRemoteExitNodeSession(
sessionId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
}
export async function invalidateAllRemoteExitNodeSessions(remoteExitNodeId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
export async function invalidateAllRemoteExitNodeSessions(
remoteExitNodeId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
}
export type SessionValidationResult =
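
Note: the validation path above is a standard sliding-expiry scheme — the stored id is the SHA-256 hex of the bearer token, and once less than half of the 30-day lifetime remains, the expiry is pushed out to a full window again. The refresh arithmetic in isolation (illustrative helper, same EXPIRES as above):

const EXPIRES = 1000 * 60 * 60 * 24 * 30; // 30 days, as above

// Mirrors the `Date.now() >= session.expiresAt - EXPIRES / 2` check:
// refresh once the session is past the halfway point of its lifetime.
function shouldRefresh(expiresAt: number, now: number = Date.now()): boolean {
    return now >= expiresAt - EXPIRES / 2;
}

// A session created at t0 expires at t0 + 30d; any successful validation
// after t0 + 15d resets expiresAt to now + 30d.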

View File

@@ -55,7 +55,6 @@ export async function getValidCertificatesForDomains(
domains: Set<string>,
useCache: boolean = true
): Promise<Array<CertificateResult>> {
loadEncryptData(); // Ensure encryption key is loaded
const finalResults: CertificateResult[] = [];

View File

@@ -12,14 +12,7 @@
*/
import { build } from "@server/build";
import {
db,
Org,
orgs,
ResourceSession,
sessions,
users
} from "@server/db";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license";

View File

@@ -66,7 +66,9 @@ export async function sendToExitNode(
// logger.debug(`Configured local exit node name: ${config.getRawConfig().gerbil.exit_node_name}`);
if (exitNode.name == config.getRawConfig().gerbil.exit_node_name) {
hostname = privateConfig.getRawPrivateConfig().gerbil.local_exit_node_reachable_at;
hostname =
privateConfig.getRawPrivateConfig().gerbil
.local_exit_node_reachable_at;
}
if (!hostname) {

View File

@@ -44,12 +44,16 @@ async function checkExitNodeOnlineStatus(
const delayBetweenAttempts = 100; // 100ms delay between starting each attempt
// Create promises for all attempts with staggered delays
const attemptPromises = Array.from({ length: maxAttempts }, async (_, index) => {
const attemptPromises = Array.from(
{ length: maxAttempts },
async (_, index) => {
const attemptNumber = index + 1;
// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) => setTimeout(resolve, delayBetweenAttempts * index));
await new Promise((resolve) =>
setTimeout(resolve, delayBetweenAttempts * index)
);
}
try {
@@ -64,15 +68,21 @@ async function checkExitNodeOnlineStatus(
);
return { success: true, attemptNumber };
}
return { success: false, attemptNumber, error: 'Non-200 status' };
return {
success: false,
attemptNumber,
error: "Non-200 status"
};
} catch (error) {
const errorMessage = error instanceof Error ? error.message : "Unknown error";
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };
}
});
}
);
try {
// Wait for the first successful response or all to fail
@@ -80,7 +90,7 @@ async function checkExitNodeOnlineStatus(
// Check if any attempt succeeded
for (const result of results) {
if (result.status === 'fulfilled' && result.value.success) {
if (result.status === "fulfilled" && result.value.success) {
return true;
}
}
@@ -137,7 +147,11 @@ export async function verifyExitNodeOrgAccess(
return { hasAccess: false, exitNode };
}
export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
export async function listExitNodes(
orgId: string,
filterOnline = false,
noCloud = false
) {
const allExitNodes = await db
.select({
exitNodeId: exitNodes.exitNodeId,
@@ -166,7 +180,10 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
eq(exitNodes.type, "gerbil"),
or(
// only choose nodes that are in the same region
eq(exitNodes.region, config.getRawPrivateConfig().app.region),
eq(
exitNodes.region,
config.getRawPrivateConfig().app.region
),
isNull(exitNodes.region) // or for enterprise where region is not set
)
),
@@ -225,7 +242,8 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
node.type === "remoteExitNode" && (!filterOnline || node.online)
);
const gerbilExitNodes = allExitNodes.filter(
(node) => node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
(node) =>
node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
);
// THIS PROVIDES THE FALL
@@ -334,7 +352,11 @@ export function selectBestExitNode(
return fallbackNode;
}
export async function checkExitNodeOrg(exitNodeId: number, orgId: string, trx: Transaction | typeof db = db) {
export async function checkExitNodeOrg(
exitNodeId: number,
orgId: string,
trx: Transaction | typeof db = db
) {
const [exitNodeOrg] = await trx
.select()
.from(exitNodeOrgs)
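
Note: checkExitNodeOnlineStatus above fires several pings with staggered start times, waits for all of them via Promise.allSettled, and treats the node as online if any fulfilled attempt succeeded. The staggering pattern, condensed (hypothetical ping callback; the real check also logs per attempt):

async function anySucceeds(
    ping: () => Promise<boolean>,
    maxAttempts = 3,
    delayBetweenAttempts = 100 // ms, as in the file above
): Promise<boolean> {
    const attempts = Array.from({ length: maxAttempts }, async (_, i) => {
        // Stagger every attempt after the first.
        if (i > 0) {
            await new Promise((r) => setTimeout(r, delayBetweenAttempts * i));
        }
        return ping().catch(() => false);
    });
    const results = await Promise.allSettled(attempts);
    return results.some((r) => r.status === "fulfilled" && r.value);
}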

View File

@@ -177,7 +177,9 @@ export class LockManager {
const exists = value !== null;
const ownedByMe =
exists &&
value!.startsWith(`${config.getRawConfig().gerbil.exit_node_name}:`);
value!.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
);
const owner = exists ? value!.split(":")[0] : undefined;
return {

View File

@@ -14,14 +14,14 @@
// Simple test file for the rate limit service with Redis
// Run with: npx ts-node rateLimitService.test.ts
import { RateLimitService } from './rateLimit';
import { RateLimitService } from "./rateLimit";
function generateClientId() {
return 'client-' + Math.random().toString(36).substring(2, 15);
return "client-" + Math.random().toString(36).substring(2, 15);
}
async function runTests() {
console.log('Starting Rate Limit Service Tests...\n');
console.log("Starting Rate Limit Service Tests...\n");
const rateLimitService = new RateLimitService();
let testsPassed = 0;
@@ -47,36 +47,54 @@ async function runTests() {
}
// Test 1: Basic rate limiting
await test('Should allow requests under the limit', async () => {
await test("Should allow requests under the limit", async () => {
const clientId = generateClientId();
const maxRequests = 5;
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
assert(result.totalHits === i + 1, `Expected ${i + 1} hits, got ${result.totalHits}`);
assert(
result.totalHits === i + 1,
`Expected ${i + 1} hits, got ${result.totalHits}`
);
}
});
// Test 2: Rate limit blocking
await test('Should block requests over the limit', async () => {
await test("Should block requests over the limit", async () => {
const clientId = generateClientId();
const maxRequests = 30;
// Use up all allowed requests
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
}
// Next request should be blocked
const blockedResult = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(blockedResult.isLimited, 'Request should be blocked');
assert(blockedResult.reason === 'global', 'Should be blocked for global reason');
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(blockedResult.isLimited, "Request should be blocked");
assert(
blockedResult.reason === "global",
"Should be blocked for global reason"
);
});
// Test 3: Message type limits
await test('Should handle message type limits', async () => {
await test("Should handle message type limits", async () => {
const clientId = generateClientId();
const globalMax = 10;
const messageTypeMax = 2;
@@ -85,53 +103,63 @@ async function runTests() {
for (let i = 0; i < messageTypeMax - 1; i++) {
const result = await rateLimitService.checkRateLimit(
clientId,
'ping',
"ping",
globalMax,
messageTypeMax
);
assert(!result.isLimited, `Ping message ${i + 1} should be allowed`);
assert(
!result.isLimited,
`Ping message ${i + 1} should be allowed`
);
}
// Next 'ping' should be blocked
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
'ping',
"ping",
globalMax,
messageTypeMax
);
assert(blockedResult.isLimited, 'Ping message should be blocked');
assert(blockedResult.reason === 'message_type:ping', 'Should be blocked for message type');
assert(blockedResult.isLimited, "Ping message should be blocked");
assert(
blockedResult.reason === "message_type:ping",
"Should be blocked for message type"
);
// Other message types should still work
const otherResult = await rateLimitService.checkRateLimit(
clientId,
'pong',
"pong",
globalMax,
messageTypeMax
);
assert(!otherResult.isLimited, 'Pong message should be allowed');
assert(!otherResult.isLimited, "Pong message should be allowed");
});
// Test 4: Reset functionality
await test('Should reset client correctly', async () => {
await test("Should reset client correctly", async () => {
const clientId = generateClientId();
const maxRequests = 3;
// Use up some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, 'test', maxRequests);
await rateLimitService.checkRateLimit(clientId, "test", maxRequests);
// Reset the client
await rateLimitService.resetKey(clientId);
// Should be able to make fresh requests
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(!result.isLimited, 'Request after reset should be allowed');
assert(result.totalHits === 1, 'Should have 1 hit after reset');
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, "Request after reset should be allowed");
assert(result.totalHits === 1, "Should have 1 hit after reset");
});
// Test 5: Different clients are independent
await test('Should handle different clients independently', async () => {
await test("Should handle different clients independently", async () => {
const client1 = generateClientId();
const client2 = generateClientId();
const maxRequests = 2;
@@ -139,43 +167,62 @@ async function runTests() {
// Client 1 uses up their limit
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
const client1Blocked = await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
assert(client1Blocked.isLimited, 'Client 1 should be blocked');
const client1Blocked = await rateLimitService.checkRateLimit(
client1,
undefined,
maxRequests
);
assert(client1Blocked.isLimited, "Client 1 should be blocked");
// Client 2 should still be able to make requests
const client2Result = await rateLimitService.checkRateLimit(client2, undefined, maxRequests);
assert(!client2Result.isLimited, 'Client 2 should not be blocked');
assert(client2Result.totalHits === 1, 'Client 2 should have 1 hit');
const client2Result = await rateLimitService.checkRateLimit(
client2,
undefined,
maxRequests
);
assert(!client2Result.isLimited, "Client 2 should not be blocked");
assert(client2Result.totalHits === 1, "Client 2 should have 1 hit");
});
// Test 6: Decrement functionality
await test('Should decrement correctly', async () => {
await test("Should decrement correctly", async () => {
const clientId = generateClientId();
const maxRequests = 5;
// Make some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
let result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits before decrement');
let result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(result.totalHits === 3, "Should have 3 hits before decrement");
// Decrement
await rateLimitService.decrementRateLimit(clientId);
// Next request should reflect the decrement
result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits after decrement + increment');
result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(
result.totalHits === 3,
"Should have 3 hits after decrement + increment"
);
});
// Wait a moment for any pending Redis operations
console.log('\nWaiting for Redis sync...');
await new Promise(resolve => setTimeout(resolve, 1000));
console.log("\nWaiting for Redis sync...");
await new Promise((resolve) => setTimeout(resolve, 1000));
// Force sync to test Redis integration
await test('Should sync to Redis', async () => {
await test("Should sync to Redis", async () => {
await rateLimitService.forceSyncAllPendingData();
// If this doesn't throw, Redis sync is working
assert(true, 'Redis sync completed');
assert(true, "Redis sync completed");
});
// Cleanup
@@ -187,16 +234,16 @@ async function runTests() {
console.log(`❌ Failed: ${testsTotal - testsPassed}/${testsTotal}`);
if (testsPassed === testsTotal) {
console.log('\n🎉 All tests passed!');
console.log("\n🎉 All tests passed!");
process.exit(0);
} else {
console.log('\n💥 Some tests failed!');
console.log("\n💥 Some tests failed!");
process.exit(1);
}
}
// Run the tests
runTests().catch(error => {
console.error('Test runner error:', error);
runTests().catch((error) => {
console.error("Test runner error:", error);
process.exit(1);
});

View File

@@ -40,7 +40,8 @@ interface RateLimitResult {
export class RateLimitService {
private localRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> =
new Map();
private cleanupInterval: NodeJS.Timeout | null = null;
private forceSyncInterval: NodeJS.Timeout | null = null;
@@ -68,12 +69,18 @@ export class RateLimitService {
return `ratelimit:${clientId}`;
}
private getMessageTypeRateLimitKey(clientId: string, messageType: string): string {
private getMessageTypeRateLimitKey(
clientId: string,
messageType: string
): string {
return `ratelimit:${clientId}:${messageType}`;
}
// Helper function to clean up old timestamp fields from a Redis hash
private async cleanupOldTimestamps(key: string, windowStart: number): Promise<void> {
private async cleanupOldTimestamps(
key: string,
windowStart: number
): Promise<void> {
if (!redisManager.isRedisEnabled()) return;
try {
@@ -101,10 +108,15 @@ export class RateLimitService {
const batch = fieldsToDelete.slice(i, i + batchSize);
await client.hdel(key, ...batch);
}
logger.debug(`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`);
logger.debug(
`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`
);
}
} catch (error) {
logger.error(`Failed to cleanup old timestamps for key ${key}:`, error);
logger.error(
`Failed to cleanup old timestamps for key ${key}:`,
error
);
// Don't throw - cleanup failures shouldn't block rate limiting
}
}
@@ -114,7 +126,8 @@ export class RateLimitService {
clientId: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;
try {
const currentTime = Math.floor(Date.now() / 1000);
@@ -132,7 +145,11 @@ export class RateLimitService {
const newValue = (
parseInt(currentValue || "0") + tracker.pendingCount
).toString();
await redisManager.hset(globalKey, currentTime.toString(), newValue);
await redisManager.hset(
globalKey,
currentTime.toString(),
newValue
);
// Set TTL using the client directly - this prevents the key from persisting forever
if (redisManager.getClient()) {
@@ -145,7 +162,9 @@ export class RateLimitService {
tracker.lastSyncedCount = tracker.count;
tracker.pendingCount = 0;
logger.debug(`Synced global rate limit to Redis for client ${clientId}`);
logger.debug(
`Synced global rate limit to Redis for client ${clientId}`
);
} catch (error) {
logger.error("Failed to sync global rate limit to Redis:", error);
}
@@ -156,12 +175,16 @@ export class RateLimitService {
messageType: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;
try {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);
// Clean up old timestamp fields before writing
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
@@ -195,12 +218,17 @@ export class RateLimitService {
`Synced message type rate limit to Redis for client ${clientId}, type ${messageType}`
);
} catch (error) {
logger.error("Failed to sync message type rate limit to Redis:", error);
logger.error(
"Failed to sync message type rate limit to Redis:",
error
);
}
}
// Initialize local tracker from Redis data
private async initializeLocalTracker(clientId: string): Promise<RateLimitTracker> {
private async initializeLocalTracker(
clientId: string
): Promise<RateLimitTracker> {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
@@ -222,7 +250,9 @@ export class RateLimitService {
const globalRateLimitData = await redisManager.hgetall(globalKey);
let count = 0;
for (const [timestamp, countStr] of Object.entries(globalRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
globalRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
@@ -236,7 +266,10 @@ export class RateLimitService {
lastSyncedCount: count
};
} catch (error) {
logger.error("Failed to initialize global tracker from Redis:", error);
logger.error(
"Failed to initialize global tracker from Redis:",
error
);
return {
count: 0,
windowStart: currentTime,
@@ -263,15 +296,21 @@ export class RateLimitService {
}
try {
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);
// Clean up old timestamp fields before reading
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
const messageTypeRateLimitData = await redisManager.hgetall(messageTypeKey);
const messageTypeRateLimitData =
await redisManager.hgetall(messageTypeKey);
let count = 0;
for (const [timestamp, countStr] of Object.entries(messageTypeRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
messageTypeRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
@@ -285,7 +324,10 @@ export class RateLimitService {
lastSyncedCount: count
};
} catch (error) {
logger.error("Failed to initialize message type tracker from Redis:", error);
logger.error(
"Failed to initialize message type tracker from Redis:",
error
);
return {
count: 0,
windowStart: currentTime,
@@ -327,7 +369,10 @@ export class RateLimitService {
isLimited: true,
reason: "global",
totalHits: globalTracker.count,
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(globalTracker.windowStart + Math.floor(windowMs / 1000)) *
1000
)
};
}
@@ -339,19 +384,32 @@ export class RateLimitService {
// Check message type specific rate limit if messageType is provided
if (messageType) {
const messageTypeKey = `${clientId}:${messageType}`;
let messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
let messageTypeTracker =
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
if (!messageTypeTracker || messageTypeTracker.windowStart < windowStart) {
if (
!messageTypeTracker ||
messageTypeTracker.windowStart < windowStart
) {
// New window or first request for this message type - initialize from Redis if available
messageTypeTracker = await this.initializeMessageTypeTracker(clientId, messageType);
messageTypeTracker = await this.initializeMessageTypeTracker(
clientId,
messageType
);
messageTypeTracker.windowStart = currentTime;
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
this.localMessageTypeRateLimitTracker.set(
messageTypeKey,
messageTypeTracker
);
}
// Increment message type counters
messageTypeTracker.count++;
messageTypeTracker.pendingCount++;
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
this.localMessageTypeRateLimitTracker.set(
messageTypeKey,
messageTypeTracker
);
// Check if message type limit would be exceeded
if (messageTypeTracker.count >= messageTypeLimit) {
@@ -359,25 +417,38 @@ export class RateLimitService {
isLimited: true,
reason: `message_type:${messageType}`,
totalHits: messageTypeTracker.count,
resetTime: new Date((messageTypeTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(messageTypeTracker.windowStart +
Math.floor(windowMs / 1000)) *
1000
)
};
}
// Sync to Redis if threshold reached
if (messageTypeTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
this.syncMessageTypeRateLimitToRedis(clientId, messageType, messageTypeTracker);
this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
messageTypeTracker
);
}
}
return {
isLimited: false,
totalHits: globalTracker.count,
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000
)
};
}
// Decrement function for skipSuccessfulRequests/skipFailedRequests functionality
async decrementRateLimit(clientId: string, messageType?: string): Promise<void> {
async decrementRateLimit(
clientId: string,
messageType?: string
): Promise<void> {
// Decrement global counter
const globalTracker = this.localRateLimitTracker.get(clientId);
if (globalTracker && globalTracker.count > 0) {
@@ -389,7 +460,8 @@ export class RateLimitService {
// Decrement message type counter if provided
if (messageType) {
const messageTypeKey = `${clientId}:${messageType}`;
const messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
const messageTypeTracker =
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
if (messageTypeTracker && messageTypeTracker.count > 0) {
messageTypeTracker.count--;
messageTypeTracker.pendingCount--;
@@ -417,9 +489,13 @@ export class RateLimitService {
// Get all message type keys for this client and delete them
const client = redisManager.getClient();
if (client) {
const messageTypeKeys = await client.keys(`ratelimit:${clientId}:*`);
const messageTypeKeys = await client.keys(
`ratelimit:${clientId}:*`
);
if (messageTypeKeys.length > 0) {
await Promise.all(messageTypeKeys.map(key => redisManager.del(key)));
await Promise.all(
messageTypeKeys.map((key) => redisManager.del(key))
);
}
}
}
@@ -431,7 +507,10 @@ export class RateLimitService {
const windowStart = currentTime - RATE_LIMIT_WINDOW;
// Clean up global rate limit tracking and sync pending data
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
for (const [
clientId,
tracker
] of this.localRateLimitTracker.entries()) {
if (tracker.windowStart < windowStart) {
// Sync any pending data before cleanup
if (tracker.pendingCount > 0) {
@@ -442,12 +521,19 @@ export class RateLimitService {
}
// Clean up message type rate limit tracking and sync pending data
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
for (const [
key,
tracker
] of this.localMessageTypeRateLimitTracker.entries()) {
if (tracker.windowStart < windowStart) {
// Sync any pending data before cleanup
if (tracker.pendingCount > 0) {
const [clientId, messageType] = key.split(":", 2);
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
await this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
tracker
);
}
this.localMessageTypeRateLimitTracker.delete(key);
}
@@ -461,17 +547,27 @@ export class RateLimitService {
logger.debug("Force syncing all pending rate limit data to Redis...");
// Sync all pending global rate limits
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
for (const [
clientId,
tracker
] of this.localRateLimitTracker.entries()) {
if (tracker.pendingCount > 0) {
await this.syncRateLimitToRedis(clientId, tracker);
}
}
// Sync all pending message type rate limits
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
for (const [
key,
tracker
] of this.localMessageTypeRateLimitTracker.entries()) {
if (tracker.pendingCount > 0) {
const [clientId, messageType] = key.split(":", 2);
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
await this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
tracker
);
}
}
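
Note: the recurring shape in this service is write batching — every hit lands in a local Map first, and Redis only sees the aggregate once pendingCount crosses REDIS_SYNC_THRESHOLD (or on cleanup/force sync), trading a little cross-node accuracy for far fewer round trips. The core of that pattern, stripped of windowing (hypothetical names; threshold value assumed):

type Tracker = { count: number; pendingCount: number };

const SYNC_THRESHOLD = 10; // assumed; the real constant is defined in the file

class BatchedCounter {
    private trackers = new Map<string, Tracker>();

    constructor(
        private sync: (key: string, pending: number) => Promise<void>
    ) {}

    async hit(key: string): Promise<number> {
        const t = this.trackers.get(key) ?? { count: 0, pendingCount: 0 };
        t.count++;
        t.pendingCount++;
        this.trackers.set(key, t);
        if (t.pendingCount >= SYNC_THRESHOLD) {
            const pending = t.pendingCount;
            t.pendingCount = 0;
            await this.sync(key, pending); // flush the batch to Redis
        }
        return t.count;
    }
}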

View File

@@ -17,7 +17,10 @@ import { MemoryStore, Store } from "express-rate-limit";
import RedisStore from "#private/lib/redisStore";
export function createStore(): Store {
if (build != "oss" && privateConfig.getRawPrivateConfig().flags.enable_redis) {
if (
build != "oss" &&
privateConfig.getRawPrivateConfig().flags.enable_redis
) {
const rateLimitStore: Store = new RedisStore({
prefix: "api-rate-limit", // Optional: customize Redis key prefix
skipFailedRequests: true, // Don't count failed requests

View File

@@ -46,7 +46,8 @@ class RedisManager {
this.isEnabled = false;
return;
}
this.isEnabled = privateConfig.getRawPrivateConfig().flags.enable_redis || false;
this.isEnabled =
privateConfig.getRawPrivateConfig().flags.enable_redis || false;
if (this.isEnabled) {
this.initializeClients();
}
@@ -63,15 +64,19 @@ class RedisManager {
}
private async triggerReconnectionCallbacks(): Promise<void> {
logger.info(`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`);
logger.info(
`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`
);
const promises = Array.from(this.reconnectionCallbacks).map(async (callback) => {
const promises = Array.from(this.reconnectionCallbacks).map(
async (callback) => {
try {
await callback();
} catch (error) {
logger.error("Error in reconnection callback:", error);
}
});
}
);
await Promise.allSettled(promises);
}
@@ -79,13 +84,17 @@ class RedisManager {
private async resubscribeToChannels(): Promise<void> {
if (!this.subscriber || this.subscribers.size === 0) return;
logger.info(`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`);
logger.info(
`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`
);
try {
const channels = Array.from(this.subscribers.keys());
if (channels.length > 0) {
await this.subscriber.subscribe(...channels);
logger.info(`Successfully re-subscribed to channels: ${channels.join(', ')}`);
logger.info(
`Successfully re-subscribed to channels: ${channels.join(", ")}`
);
}
} catch (error) {
logger.error("Failed to re-subscribe to channels:", error);
@@ -98,7 +107,7 @@ class RedisManager {
host: redisConfig.host!,
port: redisConfig.port!,
password: redisConfig.password,
db: redisConfig.db,
db: redisConfig.db
// tls: {
// rejectUnauthorized:
// redisConfig.tls?.reject_unauthorized || false
@@ -120,7 +129,7 @@ class RedisManager {
host: replica.host!,
port: replica.port!,
password: replica.password,
db: replica.db || redisConfig.db,
db: replica.db || redisConfig.db
// tls: {
// rejectUnauthorized:
// replica.tls?.reject_unauthorized || false
@@ -144,7 +153,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Initialize replica connection for reads (if available)
@@ -155,7 +164,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
} else {
// Fallback to master for reads if no replicas
@@ -172,7 +181,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Subscriber uses replica if available (reads)
@@ -182,7 +191,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Add reconnection handlers for write client
@@ -205,8 +214,11 @@ class RedisManager {
// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -236,8 +248,11 @@ class RedisManager {
// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -313,7 +328,8 @@ class RedisManager {
private updateOverallHealth(): void {
// Overall health is true if write is healthy and (read is healthy OR we don't have replicas)
this.isHealthy = this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
this.isHealthy =
this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
}
private async executeWithRetry<T>(
@@ -332,10 +348,15 @@ class RedisManager {
// If this is the last attempt, try fallback if available
if (attempt === this.maxRetries && fallbackOperation) {
try {
logger.warn(`${operationName} primary operation failed, trying fallback`);
logger.warn(
`${operationName} primary operation failed, trying fallback`
);
return await fallbackOperation();
} catch (fallbackError) {
logger.error(`${operationName} fallback also failed:`, fallbackError);
logger.error(
`${operationName} fallback also failed:`,
fallbackError
);
throw lastError;
}
}
@@ -347,18 +368,25 @@ class RedisManager {
// Calculate delay with exponential backoff
const delay = Math.min(
this.baseRetryDelay * Math.pow(this.backoffMultiplier, attempt),
this.baseRetryDelay *
Math.pow(this.backoffMultiplier, attempt),
this.maxRetryDelay
);
logger.warn(`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`, error);
logger.warn(
`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`,
error
);
// Wait before retrying
await new Promise(resolve => setTimeout(resolve, delay));
await new Promise((resolve) => setTimeout(resolve, delay));
}
}
logger.error(`${operationName} failed after ${this.maxRetries + 1} attempts:`, lastError);
logger.error(
`${operationName} failed after ${this.maxRetries + 1} attempts:`,
lastError
);
throw lastError;
}
@@ -401,23 +429,44 @@ class RedisManager {
await Promise.race([
this.writeClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Write client health check timeout')), 2000)
setTimeout(
() =>
reject(
new Error("Write client health check timeout")
),
2000
)
)
]);
this.isWriteHealthy = true;
// Check read client health if it's different from write client
if (this.hasReplicas && this.readClient && this.readClient !== this.writeClient) {
if (
this.hasReplicas &&
this.readClient &&
this.readClient !== this.writeClient
) {
try {
await Promise.race([
this.readClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Read client health check timeout')), 2000)
setTimeout(
() =>
reject(
new Error(
"Read client health check timeout"
)
),
2000
)
)
]);
this.isReadHealthy = true;
} catch (error) {
logger.error("Redis read client health check failed:", error);
logger.error(
"Redis read client health check failed:",
error
);
this.isReadHealthy = false;
}
} else {
@@ -475,16 +524,13 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.writeClient) return false;
try {
await this.executeWithRetry(
async () => {
await this.executeWithRetry(async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
},
"Redis SET"
);
}, "Redis SET");
return true;
} catch (error) {
logger.error("Redis SET error:", error);
@@ -496,7 +542,8 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.get(key)
: undefined;
@@ -560,7 +607,8 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return [];
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.smembers(key)
: undefined;
@@ -598,7 +646,8 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hget(key, field)
: undefined;
@@ -632,7 +681,8 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return {};
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hgetall(key)
: undefined;
@@ -658,18 +708,18 @@ class RedisManager {
}
try {
await this.executeWithRetry(
async () => {
await this.executeWithRetry(async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
setTimeout(
() => reject(new Error("Redis publish timeout")),
3000
)
)
]);
},
"Redis PUBLISH"
);
}, "Redis PUBLISH");
return true;
} catch (error) {
logger.error("Redis PUBLISH error:", error);
@@ -689,17 +739,20 @@ class RedisManager {
if (!this.subscribers.has(channel)) {
this.subscribers.set(channel, new Set());
// Only subscribe to the channel if it's the first subscriber
await this.executeWithRetry(
async () => {
await this.executeWithRetry(async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
setTimeout(
() =>
reject(
new Error("Redis subscribe timeout")
),
5000
)
)
]);
},
"Redis SUBSCRIBE"
);
}, "Redis SUBSCRIBE");
}
this.subscribers.get(channel)!.add(callback);

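For context, the hunks above keep repeating two patterns: racing a Redis call against a timer, and falling back to the primary when a replica read fails. A minimal sketch of both, using helper names that are assumptions (RedisManager inlines this logic instead of naming it):

// Sketch only: race an operation against a timeout, as the health
// checks and pub/sub calls above do inline.
function withTimeout<T>(
    operation: Promise<T>,
    ms: number,
    label: string
): Promise<T> {
    return Promise.race([
        operation,
        new Promise<never>((_, reject) =>
            setTimeout(() => reject(new Error(`${label} timeout`)), ms)
        )
    ]);
}

// Sketch only: try the read client first, then retry on the fallback
// (the write client) when one was provided, i.e. when replicas exist
// and the primary is healthy.
async function readWithFallback<T>(
    read: () => Promise<T>,
    fallback?: () => Promise<T>
): Promise<T> {
    try {
        return await read();
    } catch (error) {
        if (fallback) return await fallback();
        throw error;
    }
}

With helpers like these, a health check reduces to withTimeout(client.ping(), 2000, "Write client health check").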

@@ -11,9 +11,9 @@
* This file is not licensed under the AGPLv3.
*/
import { Store, Options, IncrementResponse } from 'express-rate-limit';
import { rateLimitService } from './rateLimit';
import logger from '@server/logger';
import { Store, Options, IncrementResponse } from "express-rate-limit";
import { rateLimitService } from "./rateLimit";
import logger from "@server/logger";
/**
* A Redis-backed rate limiting store for express-rate-limit that optimizes
@@ -57,12 +57,14 @@ export default class RedisStore implements Store {
*
* @param options - Configuration options for the store.
*/
constructor(options: {
constructor(
options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}) {
this.prefix = options.prefix || 'express-rate-limit';
} = {}
) {
this.prefix = options.prefix || "express-rate-limit";
this.skipFailedRequests = options.skipFailedRequests || false;
this.skipSuccessfulRequests = options.skipSuccessfulRequests || false;
}
@@ -101,7 +103,8 @@ export default class RedisStore implements Store {
return {
totalHits: result.totalHits || 1,
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore increment error for key ${key}:`, error);
@@ -158,7 +161,9 @@ export default class RedisStore implements Store {
*/
async resetAll(): Promise<void> {
try {
logger.warn('RedisStore resetAll called - this operation can be expensive');
logger.warn(
"RedisStore resetAll called - this operation can be expensive"
);
// Force sync all pending data first
await rateLimitService.forceSyncAllPendingData();
@@ -167,9 +172,9 @@ export default class RedisStore implements Store {
// scanning all Redis keys with our prefix, which could be expensive.
// In production, it's better to let entries expire naturally.
logger.info('RedisStore resetAll completed (pending data synced)');
logger.info("RedisStore resetAll completed (pending data synced)");
} catch (error) {
logger.error('RedisStore resetAll error:', error);
logger.error("RedisStore resetAll error:", error);
// Don't throw - this is an optional method
}
}
@@ -181,7 +186,9 @@ export default class RedisStore implements Store {
* @param key - The identifier for a client.
* @returns Current hit count and reset time, or null if no data exists.
*/
async getHits(key: string): Promise<{ totalHits: number; resetTime: Date } | null> {
async getHits(
key: string
): Promise<{ totalHits: number; resetTime: Date } | null> {
try {
const clientId = `${this.prefix}:${key}`;
@@ -200,7 +207,8 @@ export default class RedisStore implements Store {
return {
totalHits: Math.max(0, (result.totalHits || 0) - 1), // Adjust for the decrement
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore getHits error for key ${key}:`, error);
@@ -215,9 +223,9 @@ export default class RedisStore implements Store {
async shutdown(): Promise<void> {
try {
// The rateLimitService handles its own cleanup
logger.info('RedisStore shutdown completed');
logger.info("RedisStore shutdown completed");
} catch (error) {
logger.error('RedisStore shutdown error:', error);
logger.error("RedisStore shutdown error:", error);
}
}
}

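For context, a store like this is handed to express-rate-limit when the limiter is created. A minimal wiring sketch, assuming an import path and option values that are not shown in this diff:

import express from "express";
import rateLimit from "express-rate-limit";
import RedisStore from "./redisStore"; // assumed path

const app = express();

// Each limiter gets its own prefix so counters for different routes
// don't collide in Redis.
const apiLimiter = rateLimit({
    windowMs: 60 * 1000, // one-minute window
    limit: 100, // requests allowed per client per window
    standardHeaders: true,
    store: new RedisStore({ prefix: "api-rate-limit" })
});

app.use("/api", apiLimiter);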

@@ -33,7 +33,9 @@ export async function moveEmailToAudience(
audienceId: AudienceIds
) {
if (process.env.ENVIRONMENT !== "prod") {
logger.debug(`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`);
logger.debug(
`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`
);
return;
}
const { error, data } = await retryWithBackoff(async () => {

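The call above goes through a retryWithBackoff helper whose body is not part of this diff. A minimal sketch of what such a helper typically does, with the signature and defaults assumed:

// Sketch only: retry an async operation with exponential backoff.
async function retryWithBackoff<T>(
    operation: () => Promise<T>,
    maxAttempts = 3,
    baseDelayMs = 250
): Promise<T> {
    let lastError: unknown;
    for (let attempt = 0; attempt < maxAttempts; attempt++) {
        try {
            return await operation();
        } catch (error) {
            lastError = error;
            // 250ms, 500ms, 1000ms, ... between attempts
            await new Promise((resolve) =>
                setTimeout(resolve, baseDelayMs * 2 ** attempt)
            );
        }
    }
    throw lastError;
}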

@@ -19,10 +19,7 @@ import * as crypto from "crypto";
* @param publicKey - The public key used for verification (PEM format)
* @returns The decoded payload if validation succeeds, throws an error otherwise
*/
function validateJWT<Payload>(
token: string,
publicKey: string
): Payload {
function validateJWT<Payload>(token: string, publicKey: string): Payload {
// Split the JWT into its three parts
const parts = token.split(".");
if (parts.length !== 3) {

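Only the signature of validateJWT is collapsed in this hunk. For context, a manual RS256 verification along these lines splits the token, verifies the signature over "header.payload", and then decodes the payload — a sketch, not this file's exact body; a real implementation should also check the header algorithm and claims such as "exp":

import * as crypto from "crypto";

function validateJWT<Payload>(token: string, publicKey: string): Payload {
    const parts = token.split(".");
    if (parts.length !== 3) {
        throw new Error("Invalid JWT format");
    }
    const [headerB64, payloadB64, signatureB64] = parts;
    // The signature covers the first two base64url segments verbatim.
    const signingInput = `${headerB64}.${payloadB64}`;
    const valid = crypto.verify(
        "sha256",
        Buffer.from(signingInput),
        publicKey,
        Buffer.from(signatureB64, "base64url")
    );
    if (!valid) {
        throw new Error("Invalid JWT signature");
    }
    return JSON.parse(
        Buffer.from(payloadB64, "base64url").toString("utf8")
    ) as Payload;
}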

@@ -41,7 +41,11 @@ async function getActionDays(orgId: string): Promise<number> {
}
// store the result in cache
cache.set(`org_${orgId}_actionDays`, org.settingsLogRetentionDaysAction, 300);
cache.set(
`org_${orgId}_actionDays`,
org.settingsLogRetentionDaysAction,
300
);
return org.settingsLogRetentionDaysAction;
}
@@ -141,4 +145,3 @@ export function logActionAudit(action: ActionsEnum) {
}
};
}


@@ -28,7 +28,8 @@ export async function verifyCertificateAccess(
try {
// Assume user/org access is already verified
const orgId = req.params.orgId;
const certId = req.params.certId || req.body?.certId || req.query?.certId;
const certId =
req.params.certId || req.body?.certId || req.query?.certId;
let domainId =
req.params.domainId || req.body?.domainId || req.query?.domainId;
@@ -39,10 +40,12 @@ export async function verifyCertificateAccess(
}
if (!domainId) {
if (!certId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
createHttpError(
HttpCode.BAD_REQUEST,
"Must provide either certId or domainId"
)
);
}
@@ -75,7 +78,10 @@ export async function verifyCertificateAccess(
if (!domainId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
createHttpError(
HttpCode.BAD_REQUEST,
"Must provide either certId or domainId"
)
);
}


@@ -24,8 +24,7 @@ export async function verifyIdpAccess(
) {
try {
const userId = req.user!.userId;
const idpId =
req.params.idpId || req.body.idpId || req.query.idpId;
const idpId = req.params.idpId || req.body.idpId || req.query.idpId;
const orgId = req.params.orgId;
if (!userId) {
@@ -50,9 +49,7 @@ export async function verifyIdpAccess(
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.where(
and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId))
)
.where(and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId)))
.limit(1);
if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {


@@ -26,7 +26,8 @@ export const verifySessionRemoteExitNodeMiddleware = async (
// get the token from the auth header
const token = req.headers["authorization"]?.split(" ")[1] || "";
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (!session || !remoteExitNode) {
if (config.getRawConfig().app.log_failed_attempts) {


@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
import {
queryAccessAuditLogsParams,
queryAccessAuditLogsQuery,
queryAccess
} from "./queryAccessAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -68,8 +72,11 @@ export async function exportAccessAuditLogs(
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {


@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
import {
queryActionAuditLogsParams,
queryActionAuditLogsQuery,
queryAction
} from "./queryActionAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -68,8 +72,11 @@ export async function exportActionAuditLogs(
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {

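The two export handlers above share one download pattern: build the CSV, set the content-type and attachment headers, then send. Distilled into a sketch (the helper name is an assumption; both handlers inline these calls):

import { Response } from "express";

// Sketch only: return a generated CSV as a file download.
function sendCsvDownload(res: Response, filename: string, csvData: string) {
    res.setHeader("Content-Type", "text/csv");
    res.setHeader(
        "Content-Disposition",
        `attachment; filename="${filename}"`
    );
    return res.send(csvData);
}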

@@ -44,7 +44,8 @@ export const queryAccessAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z
.union([z.boolean(), z.string()])
@@ -181,9 +182,15 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter((row): row is { id: number; name: string | null } => row.id !== null),
locations: uniqueLocations.map(row => row.locations).filter((location): location is string => location !== null)
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
};
}


@@ -44,7 +44,8 @@ export const queryActionAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z.string().optional(),
actorType: z.string().optional(),
@@ -68,8 +69,9 @@ export const queryActionAuditLogsParams = z.object({
orgId: z.string()
});
export const queryActionAuditLogsCombined =
queryActionAuditLogsQuery.merge(queryActionAuditLogsParams);
export const queryActionAuditLogsCombined = queryActionAuditLogsQuery.merge(
queryActionAuditLogsParams
);
type Q = z.infer<typeof queryActionAuditLogsCombined>;
function getWhere(data: Q) {
@@ -78,7 +80,9 @@ function getWhere(data: Q) {
lt(actionAuditLog.timestamp, data.timeEnd),
eq(actionAuditLog.orgId, data.orgId),
data.actor ? eq(actionAuditLog.actor, data.actor) : undefined,
data.actorType ? eq(actionAuditLog.actorType, data.actorType) : undefined,
data.actorType
? eq(actionAuditLog.actorType, data.actorType)
: undefined,
data.actorId ? eq(actionAuditLog.actorId, data.actorId) : undefined,
data.action ? eq(actionAuditLog.action, data.action) : undefined
);
@@ -135,8 +139,12 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
actions: uniqueActions.map(row => row.action).filter((action): action is string => action !== null),
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
actions: uniqueActions
.map((row) => row.action)
.filter((action): action is string => action !== null)
};
}


@@ -395,7 +395,8 @@ export async function quickStart(
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
}).returning();
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
@@ -406,7 +407,12 @@ export async function quickStart(
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {


@@ -78,11 +78,23 @@ export async function getOrgUsage(
// Get usage for org
const usageData = [];
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
const siteUptime = await usageService.getUsage(
orgId,
FeatureId.SITE_UPTIME
);
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
const domains = await usageService.getUsageDaily(
orgId,
FeatureId.DOMAINS
);
const remoteExitNodes = await usageService.getUsageDaily(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
);
if (siteUptime) {
usageData.push(siteUptime);
@@ -100,7 +112,8 @@ export async function getOrgUsage(
usageData.push(remoteExitNodes);
}
const orgLimits = await db.select()
const orgLimits = await db
.select()
.from(limits)
.where(eq(limits.orgId, orgId));


@@ -31,9 +31,7 @@ export async function handleCustomerDeleted(
return;
}
await db
.delete(customers)
.where(eq(customers.customerId, customer.id));
await db.delete(customers).where(eq(customers.customerId, customer.id));
} catch (error) {
logger.error(
`Error handling customer created event for ID ${customer.id}:`,


@@ -12,7 +12,14 @@
*/
import Stripe from "stripe";
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
import {
subscriptions,
db,
subscriptionItems,
customers,
userOrgs,
users
} from "@server/db";
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
@@ -43,7 +50,6 @@ export async function handleSubscriptionDeleted(
.delete(subscriptionItems)
.where(eq(subscriptionItems.subscriptionId, subscription.id));
// Lookup customer to get orgId
const [customer] = await db
.select()
@@ -58,10 +64,7 @@ export async function handleSubscriptionDeleted(
return;
}
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
);
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
const [orgUserRes] = await db
.select()


@@ -11,11 +11,18 @@
* This file is not licensed under the AGPLv3.
*/
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
import {
freeLimitSet,
limitsService,
subscribedLimitSet
} from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import logger from "@server/logger";
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
export async function handleSubscriptionLifesycle(
orgId: string,
status: string
) {
switch (status) {
case "active":
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);


@@ -32,12 +32,13 @@ export async function billingWebhookHandler(
next: NextFunction
): Promise<any> {
let event: Stripe.Event = req.body;
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
const endpointSecret =
privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
if (!endpointSecret) {
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
logger.warn(
"Stripe webhook secret is not configured. Webhook events will not be priocessed."
);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, ""));
}
// Only verify the event if you have an endpoint secret defined.
@@ -49,7 +50,10 @@ export async function billingWebhookHandler(
if (!signature) {
logger.info("No stripe signature found in headers.");
return next(
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
createHttpError(
HttpCode.BAD_REQUEST,
"No stripe signature found in headers"
)
);
}
@@ -62,7 +66,10 @@ export async function billingWebhookHandler(
} catch (err) {
logger.error(`Webhook signature verification failed.`, err);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
createHttpError(
HttpCode.UNAUTHORIZED,
"Webhook signature verification failed"
)
);
}
}

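The handler above follows Stripe's standard webhook flow: read the stripe-signature header and let the SDK verify it against the raw request body. In outline — a sketch; the client construction and wrapper name are assumptions:

import Stripe from "stripe";

const stripe = new Stripe(process.env.STRIPE_SECRET_KEY!);

// Signature verification needs the raw, unparsed body, so the webhook
// route must be mounted with express.raw() rather than express.json().
function verifyStripeEvent(
    rawBody: Buffer,
    signature: string,
    endpointSecret: string
): Stripe.Event {
    return stripe.webhooks.constructEvent(rawBody, signature, endpointSecret);
}

If constructEvent throws, the handler above maps that to a 401; a missing signature header becomes a 400.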

@@ -42,7 +42,7 @@ async function query(domainId: string, domain: string) {
let existing: any[] = [];
if (domainRecord.type == "ns") {
const domainLevelDown = domain.split('.').slice(1).join('.');
const domainLevelDown = domain.split(".").slice(1).join(".");
existing = await db
.select({
@@ -64,7 +64,7 @@ async function query(domainId: string, domain: string) {
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
or(
eq(certificates.domain, domain),
eq(certificates.domain, domainLevelDown),
eq(certificates.domain, domainLevelDown)
)
)
);
@@ -102,8 +102,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
domainId: z
.string(),
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
})
@@ -133,7 +132,9 @@ export async function getCertificate(
if (!cert) {
logger.warn(`Certificate not found for domain: ${domainId}`);
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
return next(
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
);
}
return response<GetCertificateResponse>(res, {


@@ -36,10 +36,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
certId: z
.string()
.transform(stoi)
.pipe(z.int().positive()),
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
})
},

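The certId schema above chains transform and pipe so the path parameter arrives as a string but validates as a positive integer. A self-contained sketch (the body of the codebase's stoi helper is assumed here):

import { z } from "zod";

const stoi = (value: string) => parseInt(value, 10); // assumed helper body

const paramsSchema = z.object({
    certId: z.string().transform(stoi).pipe(z.number().int().positive()),
    orgId: z.string()
});

// "42" is coerced to the number 42; a non-numeric value becomes NaN
// and is rejected by the pipe's integer check.
const parsed = paramsSchema.parse({ certId: "42", orgId: "org_1" });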

@@ -95,16 +95,11 @@ const getRoleResourceAccessParamsSchema = z.strictObject({
roleId: z
.string()
.transform(Number)
.pipe(
z.int().positive("Role ID must be a positive integer")
),
.pipe(z.int().positive("Role ID must be a positive integer")),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getUserResourceAccessParamsSchema = z.strictObject({
@@ -112,30 +107,21 @@ const getUserResourceAccessParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getResourceRulesParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenBodySchema = z.strictObject({
@@ -1408,8 +1394,16 @@ hybridRouter.post(
);
}
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
parsedParams.data;
const {
olmId,
newtId,
ip,
port,
timestamp,
token,
publicKey,
reachableAt
} = parsedParams.data;
const destinations = await updateAndGenerateEndpointDestinations(
olmId,


@@ -18,7 +18,7 @@ import * as logs from "#private/routers/auditLogs";
import {
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess,
verifyApiKeyOrgAccess
} from "@server/middlewares";
import {
verifyValidSubscription,
@@ -26,7 +26,10 @@ import {
} from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
import {
unauthenticated as ua,
authenticated as a
} from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares";
export const unauthenticated = ua;
@@ -37,7 +40,7 @@ authenticated.post(
verifyApiKeyIsRoot, // We are the only ones who can use the root key, so it's fine
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
logActionAudit(ActionsEnum.sendUsageNotification),
org.sendUsageNotification,
org.sendUsageNotification
);
authenticated.delete(
@@ -45,7 +48,7 @@ authenticated.delete(
verifyApiKeyIsRoot,
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
logActionAudit(ActionsEnum.deleteIdp),
orgIdp.deleteOrgIdp,
orgIdp.deleteOrgIdp
);
authenticated.get(


@@ -149,12 +149,20 @@ export async function createLoginPage(
let returned: LoginPage | undefined;
await db.transaction(async (trx) => {
const orgSites = await trx
.select()
.from(sites)
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.innerJoin(
exitNodes,
eq(exitNodes.exitNodeId, sites.exitNodeId)
)
.where(
and(
eq(sites.orgId, orgId),
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
let exitNodesList = orgSites.map((s) => s.exitNodes);
@@ -163,7 +171,12 @@ export async function createLoginPage(
exitNodesList = await trx
.select()
.from(exitNodes)
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.where(
and(
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
}


@@ -78,15 +78,11 @@ export async function deleteLoginPage(
// if (!leftoverLinks.length) {
await db
.delete(loginPage)
.where(
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId));
await db
.delete(loginPageOrg)
.where(
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId));
// }
return response<LoginPage>(res, {


@@ -35,7 +35,8 @@ const paramsSchema = z
})
.strict();
const bodySchema = z.strictObject({
const bodySchema = z
.strictObject({
subdomain: subdomainSchema.nullable().optional(),
domainId: z.string().optional()
})
@@ -182,7 +183,10 @@ export async function updateLoginPage(
}
// update the full domain if it has changed
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
if (
fullDomain &&
fullDomain !== existingLoginPage?.fullDomain
) {
await db
.update(loginPage)
.set({ fullDomain })


@@ -35,10 +35,12 @@ const sendUsageNotificationBodySchema = z.object({
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
limitName: z.string(),
currentUsage: z.number(),
usageLimit: z.number(),
usageLimit: z.number()
});
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
type SendUsageNotificationRequest = z.infer<
typeof sendUsageNotificationBodySchema
>;
export type SendUsageNotificationResponse = {
success: boolean;
@@ -97,17 +99,13 @@ async function getOrgAdmins(orgId: string) {
.where(
and(
eq(userOrgs.orgId, orgId),
or(
eq(userOrgs.isOwner, true),
eq(roles.isAdmin, true)
)
or(eq(userOrgs.isOwner, true), eq(roles.isAdmin, true))
)
);
// Filter to only include users with verified emails
const orgAdmins = admins.filter(admin =>
admin.email &&
admin.email.length > 0
const orgAdmins = admins.filter(
(admin) => admin.email && admin.email.length > 0
);
return orgAdmins;
@@ -119,7 +117,9 @@ export async function sendUsageNotification(
next: NextFunction
): Promise<any> {
try {
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
const parsedParams = sendUsageNotificationParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
@@ -140,12 +140,8 @@ export async function sendUsageNotification(
}
const { orgId } = parsedParams.data;
const {
notificationType,
limitName,
currentUsage,
usageLimit,
} = parsedBody.data;
const { notificationType, limitName, currentUsage, usageLimit } =
parsedBody.data;
// Verify organization exists
const org = await db
@@ -192,7 +188,10 @@ export async function sendUsageNotification(
let template;
let subject;
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
if (
notificationType === "approaching_70" ||
notificationType === "approaching_90"
) {
template = NotifyUsageLimitApproaching({
email: admin.email,
limitName,
@@ -221,9 +220,14 @@ export async function sendUsageNotification(
emailsSent++;
adminEmails.push(admin.email);
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
logger.info(
`Usage notification sent to admin ${admin.email} for org ${orgId}`
);
} catch (emailError) {
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
logger.error(
`Failed to send usage notification to ${admin.email}:`,
emailError
);
// Continue with other admins even if one fails
}
}
@@ -239,11 +243,13 @@ export async function sendUsageNotification(
message: `Usage notifications sent to ${emailsSent} administrators`,
status: HttpCode.OK
});
} catch (error) {
logger.error("Error sending usage notifications:", error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to send usage notifications"
)
);
}
}


@@ -158,7 +158,10 @@ export async function createOrgOidcIdp(
});
});
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpId as number,
orgId
);
return response<CreateOrgIdpResponse>(res, {
data: {


@@ -66,12 +66,7 @@ export async function deleteOrgIdp(
.where(eq(idp.idpId, idpId));
if (!existingIdp) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"IdP not found"
)
);
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
}
// Delete the IDP and its related records in a transaction
@@ -82,14 +77,10 @@ export async function deleteOrgIdp(
.where(eq(idpOidcConfig.idpId, idpId));
// Delete IDP-org mappings
await trx
.delete(idpOrg)
.where(eq(idpOrg.idpId, idpId));
await trx.delete(idpOrg).where(eq(idpOrg.idpId, idpId));
// Delete the IDP itself
await trx
.delete(idp)
.where(eq(idp.idpId, idpId));
await trx.delete(idp).where(eq(idp.idpId, idpId));
});
return response<null>(res, {


@@ -93,7 +93,10 @@ export async function getOrgIdp(
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
}
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpRes.idp.idpId,
orgId
);
return response<GetOrgIdpResponse>(res, {
data: {


@@ -123,7 +123,10 @@ export async function reGenerateClientSecret(
};
// Don't await this to prevent blocking the response
sendToClient(existingOlms[0].olmId, payload).catch((error) => {
logger.error("Failed to send termination message to olm:", error);
logger.error(
"Failed to send termination message to olm:",
error
);
});
disconnectClient(existingOlms[0].olmId).catch((error) => {
@@ -133,7 +136,7 @@ export async function reGenerateClientSecret(
return response(res, {
data: {
olmId: existingOlms[0].olmId,
olmId: existingOlms[0].olmId
},
success: true,
error: false,


@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg, RemoteExitNode } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
RemoteExitNode
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -91,14 +98,15 @@ export async function reGenerateExitNodeSecret(
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingRemoteExitNode.remoteExitNodeId, payload).catch(
(error) => {
sendToClient(
existingRemoteExitNode.remoteExitNodeId,
payload
).catch((error) => {
logger.error(
"Failed to send termination message to remote exit node:",
error
);
}
);
});
disconnectClient(existingRemoteExitNode.remoteExitNodeId).catch(
(error) => {


@@ -120,15 +120,20 @@ export async function reGenerateSiteSecret(
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingNewts[0].newtId, payload).catch((error) => {
sendToClient(existingNewts[0].newtId, payload).catch(
(error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
});
}
);
disconnectClient(existingNewts[0].newtId).catch((error) => {
logger.error("Failed to disconnect newt after re-key:", error);
logger.error(
"Failed to disconnect newt after re-key:",
error
);
});
}


@@ -55,7 +55,8 @@ export async function getRemoteExitNodeToken(
try {
if (token) {
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
@@ -103,7 +104,10 @@ export async function getRemoteExitNodeToken(
}
const resToken = generateSessionToken();
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
await createRemoteExitNodeSession(
resToken,
existingRemoteExitNode.remoteExitNodeId
);
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);


@@ -33,7 +33,9 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const newlyOfflineNodes = await db
@@ -48,11 +50,13 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
isNull(exitNodes.lastPing)
)
)
).returning();
)
.returning();
// Mark the sites on those nodes as offline too if they have not pinged
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
const exitNodeIds = newlyOfflineNodes.map(
(node) => node.exitNodeId
);
const sitesOnNode = await db
.select()
@@ -77,7 +81,6 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
.where(eq(sites.siteId, site.siteId));
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
@@ -100,7 +103,9 @@ export const stopRemoteExitNodeOfflineChecker = (): void => {
/**
* Handles ping messages from clients and responds with pong
*/
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
export const handleRemoteExitNodePingMessage: MessageHandler = async (
context
) => {
const { message, client: c, sendToClient } = context;
const remoteExitNode = c as RemoteExitNode;
@@ -120,7 +125,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
.update(exitNodes)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true,
online: true
})
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
} catch (error) {
@@ -131,7 +136,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
timestamp: new Date().toISOString()
}
},
broadcast: false,

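The checker above pairs a heartbeat (lastPing refreshed on every ping message) with a periodic sweep that flips stale nodes offline and then cascades to their sites. The sweep's core condition, distilled into a sketch using the same drizzle calls (the function name is an assumption):

import { and, eq, isNull, lt, or } from "drizzle-orm";
import { db, exitNodes } from "@server/db";

const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // two minutes

// lastPing is stored as a Unix timestamp in seconds, so convert the
// cutoff before comparing.
async function sweepOfflineExitNodes() {
    const cutoff = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
    return db
        .update(exitNodes)
        .set({ online: false })
        .where(
            and(
                eq(exitNodes.online, true),
                or(lt(exitNodes.lastPing, cutoff), isNull(exitNodes.lastPing))
            )
        )
        .returning();
}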

@@ -29,7 +29,8 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
return;
}
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } = message.data;
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } =
message.data;
if (!remoteExitNodeVersion) {
logger.warn("Remote exit node version not found");
@@ -39,7 +40,10 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
// update the version
await db
.update(remoteExitNodes)
.set({ version: remoteExitNodeVersion, secondaryVersion: remoteExitNodeSecondaryVersion })
.set({
version: remoteExitNodeVersion,
secondaryVersion: remoteExitNodeSecondaryVersion
})
.where(
eq(
remoteExitNodes.remoteExitNodeId,

Some files were not shown because too many files have changed in this diff.