mirror of
https://github.com/fosrl/pangolin.git
synced 2026-02-21 04:16:38 +00:00
Chungus 2.0
This commit is contained in:
48
server/private/lib/billing/createCustomer.ts
Normal file
48
server/private/lib/billing/createCustomer.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { customers, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { build } from "@server/build";
|
||||
|
||||
export async function createCustomer(
|
||||
orgId: string,
|
||||
email: string | null | undefined
|
||||
): Promise<string | undefined> {
|
||||
if (build !== "saas") {
|
||||
return;
|
||||
}
|
||||
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
let customerId: string;
|
||||
// If we don't have a customer, create one
|
||||
if (!customer) {
|
||||
const newCustomer = await stripe!.customers.create({
|
||||
metadata: {
|
||||
orgId: orgId
|
||||
},
|
||||
email: email || undefined
|
||||
});
|
||||
customerId = newCustomer.id;
|
||||
// It will get inserted into the database by the webhook
|
||||
} else {
|
||||
customerId = customer.customerId;
|
||||
}
|
||||
return customerId;
|
||||
}
|
||||
46
server/private/lib/billing/getOrgTierData.ts
Normal file
46
server/private/lib/billing/getOrgTierData.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { getTierPriceSet } from "@server/lib/billing/tiers";
|
||||
import { getOrgSubscriptionData } from "#private/routers/billing/getOrgSubscription";
|
||||
import { build } from "@server/build";
|
||||
|
||||
export async function getOrgTierData(
|
||||
orgId: string
|
||||
): Promise<{ tier: string | null; active: boolean }> {
|
||||
let tier = null;
|
||||
let active = false;
|
||||
|
||||
if (build !== "saas") {
|
||||
return { tier, active };
|
||||
}
|
||||
|
||||
const { subscription, items } = await getOrgSubscriptionData(orgId);
|
||||
|
||||
if (items && items.length > 0) {
|
||||
const tierPriceSet = getTierPriceSet();
|
||||
// Iterate through tiers in order (earlier keys are higher tiers)
|
||||
for (const [tierId, priceId] of Object.entries(tierPriceSet)) {
|
||||
// Check if any subscription item matches this tier's price ID
|
||||
const matchingItem = items.find((item) => item.priceId === priceId);
|
||||
if (matchingItem) {
|
||||
tier = tierId;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (subscription && subscription.status === "active") {
|
||||
active = true;
|
||||
}
|
||||
return { tier, active };
|
||||
}
|
||||
15
server/private/lib/billing/index.ts
Normal file
15
server/private/lib/billing/index.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./getOrgTierData";
|
||||
export * from "./createCustomer";
|
||||
116
server/private/lib/certificates.ts
Normal file
116
server/private/lib/certificates.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import config from "./config";
|
||||
import { certificates, db } from "@server/db";
|
||||
import { and, eq, isNotNull } from "drizzle-orm";
|
||||
import { decryptData } from "@server/lib/encryption";
|
||||
import * as fs from "fs";
|
||||
|
||||
export async function getValidCertificatesForDomains(
|
||||
domains: Set<string>
|
||||
): Promise<
|
||||
Array<{
|
||||
id: number;
|
||||
domain: string;
|
||||
wildcard: boolean | null;
|
||||
certFile: string | null;
|
||||
keyFile: string | null;
|
||||
expiresAt: number | null;
|
||||
updatedAt?: number | null;
|
||||
}>
|
||||
> {
|
||||
if (domains.size === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const domainArray = Array.from(domains);
|
||||
|
||||
// TODO: add more foreign keys to make this query more efficient - we dont need to keep getting every certificate
|
||||
const validCerts = await db
|
||||
.select({
|
||||
id: certificates.certId,
|
||||
domain: certificates.domain,
|
||||
certFile: certificates.certFile,
|
||||
keyFile: certificates.keyFile,
|
||||
expiresAt: certificates.expiresAt,
|
||||
updatedAt: certificates.updatedAt,
|
||||
wildcard: certificates.wildcard
|
||||
})
|
||||
.from(certificates)
|
||||
.where(
|
||||
and(
|
||||
eq(certificates.status, "valid"),
|
||||
isNotNull(certificates.certFile),
|
||||
isNotNull(certificates.keyFile)
|
||||
)
|
||||
);
|
||||
|
||||
// Filter certificates for the specified domains and if it is a wildcard then you can match on everything up to the first dot
|
||||
const validCertsFiltered = validCerts.filter((cert) => {
|
||||
return (
|
||||
domainArray.includes(cert.domain) ||
|
||||
(cert.wildcard &&
|
||||
domainArray.some((domain) =>
|
||||
domain.endsWith(`.${cert.domain}`)
|
||||
))
|
||||
);
|
||||
});
|
||||
|
||||
const encryptionKeyPath = config.getRawPrivateConfig().server.encryption_key_path;
|
||||
|
||||
if (!fs.existsSync(encryptionKeyPath)) {
|
||||
throw new Error(
|
||||
"Encryption key file not found. Please generate one first."
|
||||
);
|
||||
}
|
||||
|
||||
const encryptionKeyHex = fs.readFileSync(encryptionKeyPath, "utf8").trim();
|
||||
const encryptionKey = Buffer.from(encryptionKeyHex, "hex");
|
||||
|
||||
const validCertsDecrypted = validCertsFiltered.map((cert) => {
|
||||
// Decrypt and save certificate file
|
||||
const decryptedCert = decryptData(
|
||||
cert.certFile!, // is not null from query
|
||||
encryptionKey
|
||||
);
|
||||
|
||||
// Decrypt and save key file
|
||||
const decryptedKey = decryptData(cert.keyFile!, encryptionKey);
|
||||
|
||||
// Return only the certificate data without org information
|
||||
return {
|
||||
...cert,
|
||||
certFile: decryptedCert,
|
||||
keyFile: decryptedKey
|
||||
};
|
||||
});
|
||||
|
||||
return validCertsDecrypted;
|
||||
}
|
||||
|
||||
export async function getValidCertificatesForDomainsHybrid(
|
||||
domains: Set<string>
|
||||
): Promise<
|
||||
Array<{
|
||||
id: number;
|
||||
domain: string;
|
||||
wildcard: boolean | null;
|
||||
certFile: string | null;
|
||||
keyFile: string | null;
|
||||
expiresAt: number | null;
|
||||
updatedAt?: number | null;
|
||||
}>
|
||||
> {
|
||||
return []; // stub
|
||||
}
|
||||
163
server/private/lib/config.ts
Normal file
163
server/private/lib/config.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { z } from "zod";
|
||||
import { __DIRNAME, APP_VERSION } from "@server/lib/consts";
|
||||
import { db } from "@server/db";
|
||||
import { SupporterKey, supporterKey } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { license } from "@server/license/license";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import {
|
||||
privateConfigSchema,
|
||||
readPrivateConfigFile
|
||||
} from "#private/lib/readConfigFile";
|
||||
import { build } from "@server/build";
|
||||
|
||||
export class PrivateConfig {
|
||||
private rawPrivateConfig!: z.infer<typeof privateConfigSchema>;
|
||||
|
||||
supporterData: SupporterKey | null = null;
|
||||
|
||||
supporterHiddenUntil: number | null = null;
|
||||
|
||||
isDev: boolean = process.env.ENVIRONMENT !== "prod";
|
||||
|
||||
constructor() {
|
||||
const privateEnvironment = readPrivateConfigFile();
|
||||
|
||||
const {
|
||||
data: parsedPrivateConfig,
|
||||
success: privateSuccess,
|
||||
error: privateError
|
||||
} = privateConfigSchema.safeParse(privateEnvironment);
|
||||
|
||||
if (!privateSuccess) {
|
||||
const errors = fromError(privateError);
|
||||
throw new Error(`Invalid private configuration file: ${errors}`);
|
||||
}
|
||||
|
||||
if (parsedPrivateConfig.branding?.colors) {
|
||||
process.env.BRANDING_COLORS = JSON.stringify(
|
||||
parsedPrivateConfig.branding?.colors
|
||||
);
|
||||
}
|
||||
|
||||
if (parsedPrivateConfig.branding?.logo?.light_path) {
|
||||
process.env.BRANDING_LOGO_LIGHT_PATH =
|
||||
parsedPrivateConfig.branding?.logo?.light_path;
|
||||
}
|
||||
if (parsedPrivateConfig.branding?.logo?.dark_path) {
|
||||
process.env.BRANDING_LOGO_DARK_PATH =
|
||||
parsedPrivateConfig.branding?.logo?.dark_path || undefined;
|
||||
}
|
||||
|
||||
if (build != "oss") {
|
||||
if (parsedPrivateConfig.branding?.logo?.light_path) {
|
||||
process.env.BRANDING_LOGO_LIGHT_PATH =
|
||||
parsedPrivateConfig.branding?.logo?.light_path;
|
||||
}
|
||||
if (parsedPrivateConfig.branding?.logo?.dark_path) {
|
||||
process.env.BRANDING_LOGO_DARK_PATH =
|
||||
parsedPrivateConfig.branding?.logo?.dark_path || undefined;
|
||||
}
|
||||
|
||||
process.env.BRANDING_LOGO_AUTH_WIDTH = parsedPrivateConfig.branding
|
||||
?.logo?.auth_page?.width
|
||||
? parsedPrivateConfig.branding?.logo?.auth_page?.width.toString()
|
||||
: undefined;
|
||||
process.env.BRANDING_LOGO_AUTH_HEIGHT = parsedPrivateConfig.branding
|
||||
?.logo?.auth_page?.height
|
||||
? parsedPrivateConfig.branding?.logo?.auth_page?.height.toString()
|
||||
: undefined;
|
||||
|
||||
process.env.BRANDING_LOGO_NAVBAR_WIDTH = parsedPrivateConfig
|
||||
.branding?.logo?.navbar?.width
|
||||
? parsedPrivateConfig.branding?.logo?.navbar?.width.toString()
|
||||
: undefined;
|
||||
process.env.BRANDING_LOGO_NAVBAR_HEIGHT = parsedPrivateConfig
|
||||
.branding?.logo?.navbar?.height
|
||||
? parsedPrivateConfig.branding?.logo?.navbar?.height.toString()
|
||||
: undefined;
|
||||
|
||||
process.env.BRANDING_FAVICON_PATH =
|
||||
parsedPrivateConfig.branding?.favicon_path;
|
||||
|
||||
process.env.BRANDING_APP_NAME =
|
||||
parsedPrivateConfig.branding?.app_name || "Pangolin";
|
||||
|
||||
if (parsedPrivateConfig.branding?.footer) {
|
||||
process.env.BRANDING_FOOTER = JSON.stringify(
|
||||
parsedPrivateConfig.branding?.footer
|
||||
);
|
||||
}
|
||||
|
||||
process.env.LOGIN_PAGE_TITLE_TEXT =
|
||||
parsedPrivateConfig.branding?.login_page?.title_text || "";
|
||||
process.env.LOGIN_PAGE_SUBTITLE_TEXT =
|
||||
parsedPrivateConfig.branding?.login_page?.subtitle_text || "";
|
||||
|
||||
process.env.SIGNUP_PAGE_TITLE_TEXT =
|
||||
parsedPrivateConfig.branding?.signup_page?.title_text || "";
|
||||
process.env.SIGNUP_PAGE_SUBTITLE_TEXT =
|
||||
parsedPrivateConfig.branding?.signup_page?.subtitle_text || "";
|
||||
|
||||
process.env.RESOURCE_AUTH_PAGE_HIDE_POWERED_BY =
|
||||
parsedPrivateConfig.branding?.resource_auth_page
|
||||
?.hide_powered_by === true
|
||||
? "true"
|
||||
: "false";
|
||||
process.env.RESOURCE_AUTH_PAGE_SHOW_LOGO =
|
||||
parsedPrivateConfig.branding?.resource_auth_page?.show_logo ===
|
||||
true
|
||||
? "true"
|
||||
: "false";
|
||||
process.env.RESOURCE_AUTH_PAGE_TITLE_TEXT =
|
||||
parsedPrivateConfig.branding?.resource_auth_page?.title_text ||
|
||||
"";
|
||||
process.env.RESOURCE_AUTH_PAGE_SUBTITLE_TEXT =
|
||||
parsedPrivateConfig.branding?.resource_auth_page
|
||||
?.subtitle_text || "";
|
||||
|
||||
if (parsedPrivateConfig.branding?.background_image_path) {
|
||||
process.env.BACKGROUND_IMAGE_PATH =
|
||||
parsedPrivateConfig.branding?.background_image_path;
|
||||
}
|
||||
|
||||
if (parsedPrivateConfig.server.reo_client_id) {
|
||||
process.env.REO_CLIENT_ID =
|
||||
parsedPrivateConfig.server.reo_client_id;
|
||||
}
|
||||
|
||||
if (parsedPrivateConfig.stripe?.s3Bucket) {
|
||||
process.env.S3_BUCKET = parsedPrivateConfig.stripe.s3Bucket;
|
||||
}
|
||||
if (parsedPrivateConfig.stripe?.localFilePath) {
|
||||
process.env.LOCAL_FILE_PATH = parsedPrivateConfig.stripe.localFilePath;
|
||||
}
|
||||
if (parsedPrivateConfig.stripe?.s3Region) {
|
||||
process.env.S3_REGION = parsedPrivateConfig.stripe.s3Region;
|
||||
}
|
||||
}
|
||||
|
||||
this.rawPrivateConfig = parsedPrivateConfig;
|
||||
}
|
||||
|
||||
public getRawPrivateConfig() {
|
||||
return this.rawPrivateConfig;
|
||||
}
|
||||
}
|
||||
|
||||
export const privateConfig = new PrivateConfig();
|
||||
|
||||
export default privateConfig;
|
||||
146
server/private/lib/exitNodes/exitNodeComms.ts
Normal file
146
server/private/lib/exitNodes/exitNodeComms.ts
Normal file
@@ -0,0 +1,146 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import axios from "axios";
|
||||
import logger from "@server/logger";
|
||||
import { db, ExitNode, remoteExitNodes } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { sendToClient } from "#private/routers/ws";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
interface ExitNodeRequest {
|
||||
remoteType?: string;
|
||||
localPath: string;
|
||||
method?: "POST" | "DELETE" | "GET" | "PUT";
|
||||
data?: any;
|
||||
queryParams?: Record<string, string>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a request to an exit node, handling both remote and local exit nodes
|
||||
* @param exitNode The exit node to send the request to
|
||||
* @param request The request configuration
|
||||
* @returns Promise<any> Response data for local nodes, undefined for remote nodes
|
||||
*/
|
||||
export async function sendToExitNode(
|
||||
exitNode: ExitNode,
|
||||
request: ExitNodeRequest
|
||||
): Promise<any> {
|
||||
if (exitNode.type === "remoteExitNode" && request.remoteType) {
|
||||
const [remoteExitNode] = await db
|
||||
.select()
|
||||
.from(remoteExitNodes)
|
||||
.where(eq(remoteExitNodes.exitNodeId, exitNode.exitNodeId))
|
||||
.limit(1);
|
||||
|
||||
if (!remoteExitNode) {
|
||||
throw new Error(
|
||||
`Remote exit node with ID ${exitNode.exitNodeId} not found`
|
||||
);
|
||||
}
|
||||
|
||||
return sendToClient(remoteExitNode.remoteExitNodeId, {
|
||||
type: request.remoteType,
|
||||
data: request.data
|
||||
});
|
||||
} else {
|
||||
let hostname = exitNode.reachableAt;
|
||||
|
||||
logger.debug(`Exit node details:`, {
|
||||
type: exitNode.type,
|
||||
name: exitNode.name,
|
||||
reachableAt: exitNode.reachableAt,
|
||||
});
|
||||
|
||||
logger.debug(`Configured local exit node name: ${config.getRawConfig().gerbil.exit_node_name}`);
|
||||
|
||||
if (exitNode.name == config.getRawConfig().gerbil.exit_node_name) {
|
||||
hostname = privateConfig.getRawPrivateConfig().gerbil.local_exit_node_reachable_at;
|
||||
}
|
||||
|
||||
if (!hostname) {
|
||||
throw new Error(
|
||||
`Exit node with ID ${exitNode.exitNodeId} is not reachable`
|
||||
);
|
||||
}
|
||||
|
||||
logger.debug(`Sending request to exit node at ${hostname}`, {
|
||||
type: request.remoteType,
|
||||
data: request.data
|
||||
});
|
||||
|
||||
// Handle local exit node with HTTP API
|
||||
const method = request.method || "POST";
|
||||
let url = `${hostname}${request.localPath}`;
|
||||
|
||||
// Add query parameters if provided
|
||||
if (request.queryParams) {
|
||||
const params = new URLSearchParams(request.queryParams);
|
||||
url += `?${params.toString()}`;
|
||||
}
|
||||
|
||||
try {
|
||||
let response;
|
||||
|
||||
switch (method) {
|
||||
case "POST":
|
||||
response = await axios.post(url, request.data, {
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
timeout: 8000
|
||||
});
|
||||
break;
|
||||
case "DELETE":
|
||||
response = await axios.delete(url, {
|
||||
timeout: 8000
|
||||
});
|
||||
break;
|
||||
case "GET":
|
||||
response = await axios.get(url, {
|
||||
timeout: 8000
|
||||
});
|
||||
break;
|
||||
case "PUT":
|
||||
response = await axios.put(url, request.data, {
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
timeout: 8000
|
||||
});
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unsupported HTTP method: ${method}`);
|
||||
}
|
||||
|
||||
logger.debug(`Exit node request successful:`, {
|
||||
method,
|
||||
url,
|
||||
status: response.data.status
|
||||
});
|
||||
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
if (axios.isAxiosError(error)) {
|
||||
logger.error(
|
||||
`Error making ${method} request (can Pangolin see Gerbil HTTP API?) for exit node at ${hostname} (status: ${error.response?.status}): ${error.message}`
|
||||
);
|
||||
} else {
|
||||
logger.error(
|
||||
`Error making ${method} request for exit node at ${hostname}: ${error}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
379
server/private/lib/exitNodes/exitNodes.ts
Normal file
379
server/private/lib/exitNodes/exitNodes.ts
Normal file
@@ -0,0 +1,379 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import {
|
||||
db,
|
||||
exitNodes,
|
||||
exitNodeOrgs,
|
||||
resources,
|
||||
targets,
|
||||
sites,
|
||||
targetHealthCheck
|
||||
} from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import { ExitNodePingResult } from "@server/routers/newt";
|
||||
import { eq, and, or, ne, isNull } from "drizzle-orm";
|
||||
import axios from "axios";
|
||||
import config from "../config";
|
||||
|
||||
/**
|
||||
* Checks if an exit node is actually online by making HTTP requests to its endpoint/ping
|
||||
* Makes up to 3 attempts in parallel with small delays, returns as soon as one succeeds
|
||||
*/
|
||||
async function checkExitNodeOnlineStatus(
|
||||
endpoint: string | undefined
|
||||
): Promise<boolean> {
|
||||
if (!endpoint || endpoint == "") {
|
||||
// the endpoint can start out as a empty string
|
||||
return false;
|
||||
}
|
||||
|
||||
const maxAttempts = 3;
|
||||
const timeoutMs = 5000; // 5 second timeout per request
|
||||
const delayBetweenAttempts = 100; // 100ms delay between starting each attempt
|
||||
|
||||
// Create promises for all attempts with staggered delays
|
||||
const attemptPromises = Array.from({ length: maxAttempts }, async (_, index) => {
|
||||
const attemptNumber = index + 1;
|
||||
|
||||
// Add delay before each attempt (except the first)
|
||||
if (index > 0) {
|
||||
await new Promise((resolve) => setTimeout(resolve, delayBetweenAttempts * index));
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.get(`http://${endpoint}/ping`, {
|
||||
timeout: timeoutMs,
|
||||
validateStatus: (status) => status === 200
|
||||
});
|
||||
|
||||
if (response.status === 200) {
|
||||
logger.debug(
|
||||
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
|
||||
);
|
||||
return { success: true, attemptNumber };
|
||||
}
|
||||
return { success: false, attemptNumber, error: 'Non-200 status' };
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error";
|
||||
logger.debug(
|
||||
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
|
||||
);
|
||||
return { success: false, attemptNumber, error: errorMessage };
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
// Wait for the first successful response or all to fail
|
||||
const results = await Promise.allSettled(attemptPromises);
|
||||
|
||||
// Check if any attempt succeeded
|
||||
for (const result of results) {
|
||||
if (result.status === 'fulfilled' && result.value.success) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// All attempts failed
|
||||
logger.warn(
|
||||
`Exit node ${endpoint} is offline after ${maxAttempts} parallel attempts`
|
||||
);
|
||||
return false;
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Unexpected error checking exit node ${endpoint}: ${error instanceof Error ? error.message : "Unknown error"}`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function verifyExitNodeOrgAccess(
|
||||
exitNodeId: number,
|
||||
orgId: string
|
||||
) {
|
||||
const [result] = await db
|
||||
.select({
|
||||
exitNode: exitNodes,
|
||||
exitNodeOrgId: exitNodeOrgs.exitNodeId
|
||||
})
|
||||
.from(exitNodes)
|
||||
.leftJoin(
|
||||
exitNodeOrgs,
|
||||
and(
|
||||
eq(exitNodeOrgs.exitNodeId, exitNodes.exitNodeId),
|
||||
eq(exitNodeOrgs.orgId, orgId)
|
||||
)
|
||||
)
|
||||
.where(eq(exitNodes.exitNodeId, exitNodeId));
|
||||
|
||||
if (!result) {
|
||||
return { hasAccess: false, exitNode: null };
|
||||
}
|
||||
|
||||
const { exitNode } = result;
|
||||
|
||||
// If the exit node is type "gerbil", access is allowed
|
||||
if (exitNode.type === "gerbil") {
|
||||
return { hasAccess: true, exitNode };
|
||||
}
|
||||
|
||||
// If the exit node is type "remoteExitNode", check if it has org access
|
||||
if (exitNode.type === "remoteExitNode") {
|
||||
return { hasAccess: !!result.exitNodeOrgId, exitNode };
|
||||
}
|
||||
|
||||
// For any other type, deny access
|
||||
return { hasAccess: false, exitNode };
|
||||
}
|
||||
|
||||
export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
|
||||
const allExitNodes = await db
|
||||
.select({
|
||||
exitNodeId: exitNodes.exitNodeId,
|
||||
name: exitNodes.name,
|
||||
address: exitNodes.address,
|
||||
endpoint: exitNodes.endpoint,
|
||||
publicKey: exitNodes.publicKey,
|
||||
listenPort: exitNodes.listenPort,
|
||||
reachableAt: exitNodes.reachableAt,
|
||||
maxConnections: exitNodes.maxConnections,
|
||||
online: exitNodes.online,
|
||||
lastPing: exitNodes.lastPing,
|
||||
type: exitNodes.type,
|
||||
orgId: exitNodeOrgs.orgId,
|
||||
region: exitNodes.region
|
||||
})
|
||||
.from(exitNodes)
|
||||
.leftJoin(
|
||||
exitNodeOrgs,
|
||||
eq(exitNodes.exitNodeId, exitNodeOrgs.exitNodeId)
|
||||
)
|
||||
.where(
|
||||
or(
|
||||
// Include all exit nodes that are NOT of type remoteExitNode
|
||||
and(
|
||||
eq(exitNodes.type, "gerbil"),
|
||||
or(
|
||||
// only choose nodes that are in the same region
|
||||
eq(exitNodes.region, config.getRawPrivateConfig().app.region),
|
||||
isNull(exitNodes.region) // or for enterprise where region is not set
|
||||
)
|
||||
),
|
||||
// Include remoteExitNode types where the orgId matches the newt's organization
|
||||
and(
|
||||
eq(exitNodes.type, "remoteExitNode"),
|
||||
eq(exitNodeOrgs.orgId, orgId)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Filter the nodes. If there are NO remoteExitNodes then do nothing. If there are then remove all of the non-remoteExitNodes
|
||||
if (allExitNodes.length === 0) {
|
||||
logger.warn("No exit nodes found for ping request!");
|
||||
return [];
|
||||
}
|
||||
|
||||
// Enhanced online checking: consider node offline if either DB says offline OR HTTP ping fails
|
||||
const nodesWithRealOnlineStatus = await Promise.all(
|
||||
allExitNodes.map(async (node) => {
|
||||
// If database says it's online, verify with HTTP ping
|
||||
let online: boolean;
|
||||
if (filterOnline && node.type == "remoteExitNode") {
|
||||
try {
|
||||
const isActuallyOnline = await checkExitNodeOnlineStatus(
|
||||
node.endpoint
|
||||
);
|
||||
|
||||
// set the item in the database if it is offline
|
||||
if (isActuallyOnline != node.online) {
|
||||
await db
|
||||
.update(exitNodes)
|
||||
.set({ online: isActuallyOnline })
|
||||
.where(eq(exitNodes.exitNodeId, node.exitNodeId));
|
||||
}
|
||||
online = isActuallyOnline;
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Failed to check online status for exit node ${node.name} (${node.endpoint}): ${error instanceof Error ? error.message : "Unknown error"}`
|
||||
);
|
||||
online = false;
|
||||
}
|
||||
} else {
|
||||
online = node.online;
|
||||
}
|
||||
|
||||
return {
|
||||
...node,
|
||||
online
|
||||
};
|
||||
})
|
||||
);
|
||||
|
||||
const remoteExitNodes = nodesWithRealOnlineStatus.filter(
|
||||
(node) =>
|
||||
node.type === "remoteExitNode" && (!filterOnline || node.online)
|
||||
);
|
||||
const gerbilExitNodes = nodesWithRealOnlineStatus.filter(
|
||||
(node) => node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
|
||||
);
|
||||
|
||||
// THIS PROVIDES THE FALL
|
||||
const exitNodesList =
|
||||
remoteExitNodes.length > 0 ? remoteExitNodes : gerbilExitNodes;
|
||||
|
||||
return exitNodesList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects the most suitable exit node from a list of ping results.
|
||||
*
|
||||
* The selection algorithm follows these steps:
|
||||
*
|
||||
* 1. **Filter Invalid Nodes**: Excludes nodes with errors or zero weight.
|
||||
*
|
||||
* 2. **Sort by Latency**: Sorts valid nodes in ascending order of latency.
|
||||
*
|
||||
* 3. **Preferred Selection**:
|
||||
* - If the lowest-latency node has sufficient capacity (≥10% weight),
|
||||
* check if a previously connected node is also acceptable.
|
||||
* - The previously connected node is preferred if its latency is within
|
||||
* 30ms or 15% of the best node’s latency.
|
||||
*
|
||||
* 4. **Fallback to Next Best**:
|
||||
* - If the lowest-latency node is under capacity, find the next node
|
||||
* with acceptable capacity.
|
||||
*
|
||||
* 5. **Final Fallback**:
|
||||
* - If no nodes meet the capacity threshold, fall back to the node
|
||||
* with the highest weight (i.e., most available capacity).
|
||||
*
|
||||
*/
|
||||
export function selectBestExitNode(
|
||||
pingResults: ExitNodePingResult[]
|
||||
): ExitNodePingResult | null {
|
||||
const MIN_CAPACITY_THRESHOLD = 0.1;
|
||||
const LATENCY_TOLERANCE_MS = 30;
|
||||
const LATENCY_TOLERANCE_PERCENT = 0.15;
|
||||
|
||||
// Filter out invalid nodes
|
||||
const validNodes = pingResults.filter((n) => !n.error && n.weight > 0);
|
||||
|
||||
if (validNodes.length === 0) {
|
||||
logger.error("No valid exit nodes available");
|
||||
return null;
|
||||
}
|
||||
|
||||
// Sort by latency (ascending)
|
||||
const sortedNodes = validNodes
|
||||
.slice()
|
||||
.sort((a, b) => a.latencyMs - b.latencyMs);
|
||||
const lowestLatencyNode = sortedNodes[0];
|
||||
|
||||
logger.debug(
|
||||
`Lowest latency node: ${lowestLatencyNode.exitNodeName} (${lowestLatencyNode.latencyMs} ms, weight=${lowestLatencyNode.weight.toFixed(2)})`
|
||||
);
|
||||
|
||||
// If lowest latency node has enough capacity, check if previously connected node is acceptable
|
||||
if (lowestLatencyNode.weight >= MIN_CAPACITY_THRESHOLD) {
|
||||
const previouslyConnectedNode = sortedNodes.find(
|
||||
(n) =>
|
||||
n.wasPreviouslyConnected && n.weight >= MIN_CAPACITY_THRESHOLD
|
||||
);
|
||||
|
||||
if (previouslyConnectedNode) {
|
||||
const latencyDiff =
|
||||
previouslyConnectedNode.latencyMs - lowestLatencyNode.latencyMs;
|
||||
const percentDiff = latencyDiff / lowestLatencyNode.latencyMs;
|
||||
|
||||
if (
|
||||
latencyDiff <= LATENCY_TOLERANCE_MS ||
|
||||
percentDiff <= LATENCY_TOLERANCE_PERCENT
|
||||
) {
|
||||
logger.info(
|
||||
`Sticking with previously connected node: ${previouslyConnectedNode.exitNodeName} ` +
|
||||
`(${previouslyConnectedNode.latencyMs} ms), latency diff = ${latencyDiff.toFixed(1)}ms ` +
|
||||
`/ ${(percentDiff * 100).toFixed(1)}%.`
|
||||
);
|
||||
return previouslyConnectedNode;
|
||||
}
|
||||
}
|
||||
|
||||
return lowestLatencyNode;
|
||||
}
|
||||
|
||||
// Otherwise, find the next node (after the lowest) that has enough capacity
|
||||
for (let i = 1; i < sortedNodes.length; i++) {
|
||||
const node = sortedNodes[i];
|
||||
if (node.weight >= MIN_CAPACITY_THRESHOLD) {
|
||||
logger.info(
|
||||
`Lowest latency node under capacity. Using next best: ${node.exitNodeName} ` +
|
||||
`(${node.latencyMs} ms, weight=${node.weight.toFixed(2)})`
|
||||
);
|
||||
return node;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: pick the highest weight node
|
||||
const fallbackNode = validNodes.reduce((a, b) =>
|
||||
a.weight > b.weight ? a : b
|
||||
);
|
||||
logger.warn(
|
||||
`No nodes with ≥10% weight. Falling back to highest capacity node: ${fallbackNode.exitNodeName}`
|
||||
);
|
||||
return fallbackNode;
|
||||
}
|
||||
|
||||
export async function checkExitNodeOrg(exitNodeId: number, orgId: string) {
|
||||
const [exitNodeOrg] = await db
|
||||
.select()
|
||||
.from(exitNodeOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(exitNodeOrgs.exitNodeId, exitNodeId),
|
||||
eq(exitNodeOrgs.orgId, orgId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!exitNodeOrg) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export async function resolveExitNodes(hostname: string, publicKey: string) {
|
||||
const resourceExitNodes = await db
|
||||
.select({
|
||||
endpoint: exitNodes.endpoint,
|
||||
publicKey: exitNodes.publicKey,
|
||||
orgId: resources.orgId
|
||||
})
|
||||
.from(resources)
|
||||
.innerJoin(targets, eq(resources.resourceId, targets.resourceId))
|
||||
.leftJoin(
|
||||
targetHealthCheck,
|
||||
eq(targetHealthCheck.targetId, targets.targetId)
|
||||
)
|
||||
.innerJoin(sites, eq(targets.siteId, sites.siteId))
|
||||
.innerJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId))
|
||||
.where(
|
||||
and(
|
||||
eq(resources.fullDomain, hostname),
|
||||
ne(exitNodes.publicKey, publicKey),
|
||||
ne(targetHealthCheck.hcHealth, "unhealthy")
|
||||
)
|
||||
);
|
||||
|
||||
return resourceExitNodes;
|
||||
}
|
||||
15
server/private/lib/exitNodes/index.ts
Normal file
15
server/private/lib/exitNodes/index.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./exitNodeComms";
|
||||
export * from "./exitNodes";
|
||||
202
server/private/lib/rateLimit.test.ts
Normal file
202
server/private/lib/rateLimit.test.ts
Normal file
@@ -0,0 +1,202 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
// Simple test file for the rate limit service with Redis
|
||||
// Run with: npx ts-node rateLimitService.test.ts
|
||||
|
||||
import { RateLimitService } from './rateLimit';
|
||||
|
||||
function generateClientId() {
|
||||
return 'client-' + Math.random().toString(36).substring(2, 15);
|
||||
}
|
||||
|
||||
async function runTests() {
|
||||
console.log('Starting Rate Limit Service Tests...\n');
|
||||
|
||||
const rateLimitService = new RateLimitService();
|
||||
let testsPassed = 0;
|
||||
let testsTotal = 0;
|
||||
|
||||
// Helper function to run a test
|
||||
async function test(name: string, testFn: () => Promise<void>) {
|
||||
testsTotal++;
|
||||
try {
|
||||
await testFn();
|
||||
console.log(`✅ ${name}`);
|
||||
testsPassed++;
|
||||
} catch (error) {
|
||||
console.log(`❌ ${name}: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for assertions
|
||||
function assert(condition: boolean, message: string) {
|
||||
if (!condition) {
|
||||
throw new Error(message);
|
||||
}
|
||||
}
|
||||
|
||||
// Test 1: Basic rate limiting
|
||||
await test('Should allow requests under the limit', async () => {
|
||||
const clientId = generateClientId();
|
||||
const maxRequests = 5;
|
||||
|
||||
for (let i = 0; i < maxRequests - 1; i++) {
|
||||
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
|
||||
assert(result.totalHits === i + 1, `Expected ${i + 1} hits, got ${result.totalHits}`);
|
||||
}
|
||||
});
|
||||
|
||||
// Test 2: Rate limit blocking
|
||||
await test('Should block requests over the limit', async () => {
|
||||
const clientId = generateClientId();
|
||||
const maxRequests = 30;
|
||||
|
||||
// Use up all allowed requests
|
||||
for (let i = 0; i < maxRequests - 1; i++) {
|
||||
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
|
||||
}
|
||||
|
||||
// Next request should be blocked
|
||||
const blockedResult = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
assert(blockedResult.isLimited, 'Request should be blocked');
|
||||
assert(blockedResult.reason === 'global', 'Should be blocked for global reason');
|
||||
});
|
||||
|
||||
// Test 3: Message type limits
|
||||
await test('Should handle message type limits', async () => {
|
||||
const clientId = generateClientId();
|
||||
const globalMax = 10;
|
||||
const messageTypeMax = 2;
|
||||
|
||||
// Send messages of type 'ping' up to the limit
|
||||
for (let i = 0; i < messageTypeMax - 1; i++) {
|
||||
const result = await rateLimitService.checkRateLimit(
|
||||
clientId,
|
||||
'ping',
|
||||
globalMax,
|
||||
messageTypeMax
|
||||
);
|
||||
assert(!result.isLimited, `Ping message ${i + 1} should be allowed`);
|
||||
}
|
||||
|
||||
// Next 'ping' should be blocked
|
||||
const blockedResult = await rateLimitService.checkRateLimit(
|
||||
clientId,
|
||||
'ping',
|
||||
globalMax,
|
||||
messageTypeMax
|
||||
);
|
||||
assert(blockedResult.isLimited, 'Ping message should be blocked');
|
||||
assert(blockedResult.reason === 'message_type:ping', 'Should be blocked for message type');
|
||||
|
||||
// Other message types should still work
|
||||
const otherResult = await rateLimitService.checkRateLimit(
|
||||
clientId,
|
||||
'pong',
|
||||
globalMax,
|
||||
messageTypeMax
|
||||
);
|
||||
assert(!otherResult.isLimited, 'Pong message should be allowed');
|
||||
});
|
||||
|
||||
// Test 4: Reset functionality
|
||||
await test('Should reset client correctly', async () => {
|
||||
const clientId = generateClientId();
|
||||
const maxRequests = 3;
|
||||
|
||||
// Use up some requests
|
||||
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
await rateLimitService.checkRateLimit(clientId, 'test', maxRequests);
|
||||
|
||||
// Reset the client
|
||||
await rateLimitService.resetKey(clientId);
|
||||
|
||||
// Should be able to make fresh requests
|
||||
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
assert(!result.isLimited, 'Request after reset should be allowed');
|
||||
assert(result.totalHits === 1, 'Should have 1 hit after reset');
|
||||
});
|
||||
|
||||
// Test 5: Different clients are independent
|
||||
await test('Should handle different clients independently', async () => {
|
||||
const client1 = generateClientId();
|
||||
const client2 = generateClientId();
|
||||
const maxRequests = 2;
|
||||
|
||||
// Client 1 uses up their limit
|
||||
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
|
||||
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
|
||||
const client1Blocked = await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
|
||||
assert(client1Blocked.isLimited, 'Client 1 should be blocked');
|
||||
|
||||
// Client 2 should still be able to make requests
|
||||
const client2Result = await rateLimitService.checkRateLimit(client2, undefined, maxRequests);
|
||||
assert(!client2Result.isLimited, 'Client 2 should not be blocked');
|
||||
assert(client2Result.totalHits === 1, 'Client 2 should have 1 hit');
|
||||
});
|
||||
|
||||
// Test 6: Decrement functionality
|
||||
await test('Should decrement correctly', async () => {
|
||||
const clientId = generateClientId();
|
||||
const maxRequests = 5;
|
||||
|
||||
// Make some requests
|
||||
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
let result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
assert(result.totalHits === 3, 'Should have 3 hits before decrement');
|
||||
|
||||
// Decrement
|
||||
await rateLimitService.decrementRateLimit(clientId);
|
||||
|
||||
// Next request should reflect the decrement
|
||||
result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
|
||||
assert(result.totalHits === 3, 'Should have 3 hits after decrement + increment');
|
||||
});
|
||||
|
||||
// Wait a moment for any pending Redis operations
|
||||
console.log('\nWaiting for Redis sync...');
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
// Force sync to test Redis integration
|
||||
await test('Should sync to Redis', async () => {
|
||||
await rateLimitService.forceSyncAllPendingData();
|
||||
// If this doesn't throw, Redis sync is working
|
||||
assert(true, 'Redis sync completed');
|
||||
});
|
||||
|
||||
// Cleanup
|
||||
await rateLimitService.cleanup();
|
||||
|
||||
// Results
|
||||
console.log(`\n--- Test Results ---`);
|
||||
console.log(`✅ Passed: ${testsPassed}/${testsTotal}`);
|
||||
console.log(`❌ Failed: ${testsTotal - testsPassed}/${testsTotal}`);
|
||||
|
||||
if (testsPassed === testsTotal) {
|
||||
console.log('\n🎉 All tests passed!');
|
||||
process.exit(0);
|
||||
} else {
|
||||
console.log('\n💥 Some tests failed!');
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Run the tests
|
||||
runTests().catch(error => {
|
||||
console.error('Test runner error:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
454
server/private/lib/rateLimit.ts
Normal file
454
server/private/lib/rateLimit.ts
Normal file
@@ -0,0 +1,454 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import logger from "@server/logger";
|
||||
import redisManager from "@server/private/lib/redis";
|
||||
import { build } from "@server/build";
|
||||
|
||||
// Rate limiting configuration
|
||||
export const RATE_LIMIT_WINDOW = 60; // 1 minute in seconds
|
||||
export const RATE_LIMIT_MAX_REQUESTS = 100;
|
||||
export const RATE_LIMIT_PER_MESSAGE_TYPE = 20; // Per message type limit within the window
|
||||
|
||||
// Configuration for batched Redis sync
|
||||
export const REDIS_SYNC_THRESHOLD = 15; // Sync to Redis every N messages
|
||||
export const REDIS_SYNC_FORCE_INTERVAL = 30000; // Force sync every 30 seconds as backup
|
||||
|
||||
interface RateLimitTracker {
|
||||
count: number;
|
||||
windowStart: number;
|
||||
pendingCount: number;
|
||||
lastSyncedCount: number;
|
||||
}
|
||||
|
||||
interface RateLimitResult {
|
||||
isLimited: boolean;
|
||||
reason?: string;
|
||||
totalHits?: number;
|
||||
resetTime?: Date;
|
||||
}
|
||||
|
||||
export class RateLimitService {
|
||||
private localRateLimitTracker: Map<string, RateLimitTracker> = new Map();
|
||||
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> = new Map();
|
||||
private cleanupInterval: NodeJS.Timeout | null = null;
|
||||
private forceSyncInterval: NodeJS.Timeout | null = null;
|
||||
|
||||
constructor() {
|
||||
if (build == "oss") {
|
||||
return;
|
||||
}
|
||||
|
||||
// Start cleanup and sync intervals
|
||||
this.cleanupInterval = setInterval(() => {
|
||||
this.cleanupLocalRateLimit().catch((error) => {
|
||||
logger.error("Error during rate limit cleanup:", error);
|
||||
});
|
||||
}, 60000); // Run cleanup every minute
|
||||
|
||||
this.forceSyncInterval = setInterval(() => {
|
||||
this.forceSyncAllPendingData().catch((error) => {
|
||||
logger.error("Error during force sync:", error);
|
||||
});
|
||||
}, REDIS_SYNC_FORCE_INTERVAL);
|
||||
}
|
||||
|
||||
// Redis keys
|
||||
private getRateLimitKey(clientId: string): string {
|
||||
return `ratelimit:${clientId}`;
|
||||
}
|
||||
|
||||
private getMessageTypeRateLimitKey(clientId: string, messageType: string): string {
|
||||
return `ratelimit:${clientId}:${messageType}`;
|
||||
}
|
||||
|
||||
// Helper function to sync local rate limit data to Redis
|
||||
private async syncRateLimitToRedis(
|
||||
clientId: string,
|
||||
tracker: RateLimitTracker
|
||||
): Promise<void> {
|
||||
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
|
||||
|
||||
try {
|
||||
const currentTime = Math.floor(Date.now() / 1000);
|
||||
const globalKey = this.getRateLimitKey(clientId);
|
||||
|
||||
// Get current value and add pending count
|
||||
const currentValue = await redisManager.hget(
|
||||
globalKey,
|
||||
currentTime.toString()
|
||||
);
|
||||
const newValue = (
|
||||
parseInt(currentValue || "0") + tracker.pendingCount
|
||||
).toString();
|
||||
await redisManager.hset(globalKey, currentTime.toString(), newValue);
|
||||
|
||||
// Set TTL using the client directly
|
||||
if (redisManager.getClient()) {
|
||||
await redisManager
|
||||
.getClient()
|
||||
.expire(globalKey, RATE_LIMIT_WINDOW + 10);
|
||||
}
|
||||
|
||||
// Update tracking
|
||||
tracker.lastSyncedCount = tracker.count;
|
||||
tracker.pendingCount = 0;
|
||||
|
||||
logger.debug(`Synced global rate limit to Redis for client ${clientId}`);
|
||||
} catch (error) {
|
||||
logger.error("Failed to sync global rate limit to Redis:", error);
|
||||
}
|
||||
}
|
||||
|
||||
private async syncMessageTypeRateLimitToRedis(
|
||||
clientId: string,
|
||||
messageType: string,
|
||||
tracker: RateLimitTracker
|
||||
): Promise<void> {
|
||||
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
|
||||
|
||||
try {
|
||||
const currentTime = Math.floor(Date.now() / 1000);
|
||||
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
|
||||
|
||||
// Get current value and add pending count
|
||||
const currentValue = await redisManager.hget(
|
||||
messageTypeKey,
|
||||
currentTime.toString()
|
||||
);
|
||||
const newValue = (
|
||||
parseInt(currentValue || "0") + tracker.pendingCount
|
||||
).toString();
|
||||
await redisManager.hset(
|
||||
messageTypeKey,
|
||||
currentTime.toString(),
|
||||
newValue
|
||||
);
|
||||
|
||||
// Set TTL using the client directly
|
||||
if (redisManager.getClient()) {
|
||||
await redisManager
|
||||
.getClient()
|
||||
.expire(messageTypeKey, RATE_LIMIT_WINDOW + 10);
|
||||
}
|
||||
|
||||
// Update tracking
|
||||
tracker.lastSyncedCount = tracker.count;
|
||||
tracker.pendingCount = 0;
|
||||
|
||||
logger.debug(
|
||||
`Synced message type rate limit to Redis for client ${clientId}, type ${messageType}`
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to sync message type rate limit to Redis:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize local tracker from Redis data
|
||||
private async initializeLocalTracker(clientId: string): Promise<RateLimitTracker> {
|
||||
const currentTime = Math.floor(Date.now() / 1000);
|
||||
const windowStart = currentTime - RATE_LIMIT_WINDOW;
|
||||
|
||||
if (!redisManager.isRedisEnabled()) {
|
||||
return {
|
||||
count: 0,
|
||||
windowStart: currentTime,
|
||||
pendingCount: 0,
|
||||
lastSyncedCount: 0
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const globalKey = this.getRateLimitKey(clientId);
|
||||
const globalRateLimitData = await redisManager.hgetall(globalKey);
|
||||
|
||||
let count = 0;
|
||||
for (const [timestamp, countStr] of Object.entries(globalRateLimitData)) {
|
||||
const time = parseInt(timestamp);
|
||||
if (time >= windowStart) {
|
||||
count += parseInt(countStr);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
count,
|
||||
windowStart: currentTime,
|
||||
pendingCount: 0,
|
||||
lastSyncedCount: count
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error("Failed to initialize global tracker from Redis:", error);
|
||||
return {
|
||||
count: 0,
|
||||
windowStart: currentTime,
|
||||
pendingCount: 0,
|
||||
lastSyncedCount: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async initializeMessageTypeTracker(
|
||||
clientId: string,
|
||||
messageType: string
|
||||
): Promise<RateLimitTracker> {
|
||||
const currentTime = Math.floor(Date.now() / 1000);
|
||||
const windowStart = currentTime - RATE_LIMIT_WINDOW;
|
||||
|
||||
if (!redisManager.isRedisEnabled()) {
|
||||
return {
|
||||
count: 0,
|
||||
windowStart: currentTime,
|
||||
pendingCount: 0,
|
||||
lastSyncedCount: 0
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
|
||||
const messageTypeRateLimitData = await redisManager.hgetall(messageTypeKey);
|
||||
|
||||
let count = 0;
|
||||
for (const [timestamp, countStr] of Object.entries(messageTypeRateLimitData)) {
|
||||
const time = parseInt(timestamp);
|
||||
if (time >= windowStart) {
|
||||
count += parseInt(countStr);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
count,
|
||||
windowStart: currentTime,
|
||||
pendingCount: 0,
|
||||
lastSyncedCount: count
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error("Failed to initialize message type tracker from Redis:", error);
|
||||
return {
|
||||
count: 0,
|
||||
windowStart: currentTime,
|
||||
pendingCount: 0,
|
||||
lastSyncedCount: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Main rate limiting function
|
||||
async checkRateLimit(
|
||||
clientId: string,
|
||||
messageType?: string,
|
||||
maxRequests: number = RATE_LIMIT_MAX_REQUESTS,
|
||||
messageTypeLimit: number = RATE_LIMIT_PER_MESSAGE_TYPE,
|
||||
windowMs: number = RATE_LIMIT_WINDOW * 1000
|
||||
): Promise<RateLimitResult> {
|
||||
const currentTime = Math.floor(Date.now() / 1000);
|
||||
const windowStart = currentTime - Math.floor(windowMs / 1000);
|
||||
|
||||
// Check global rate limit
|
||||
let globalTracker = this.localRateLimitTracker.get(clientId);
|
||||
|
||||
if (!globalTracker || globalTracker.windowStart < windowStart) {
|
||||
// New window or first request - initialize from Redis if available
|
||||
globalTracker = await this.initializeLocalTracker(clientId);
|
||||
globalTracker.windowStart = currentTime;
|
||||
this.localRateLimitTracker.set(clientId, globalTracker);
|
||||
}
|
||||
|
||||
// Increment global counters
|
||||
globalTracker.count++;
|
||||
globalTracker.pendingCount++;
|
||||
this.localRateLimitTracker.set(clientId, globalTracker);
|
||||
|
||||
// Check if global limit would be exceeded
|
||||
if (globalTracker.count >= maxRequests) {
|
||||
return {
|
||||
isLimited: true,
|
||||
reason: "global",
|
||||
totalHits: globalTracker.count,
|
||||
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
|
||||
};
|
||||
}
|
||||
|
||||
// Sync to Redis if threshold reached
|
||||
if (globalTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
|
||||
this.syncRateLimitToRedis(clientId, globalTracker);
|
||||
}
|
||||
|
||||
// Check message type specific rate limit if messageType is provided
|
||||
if (messageType) {
|
||||
const messageTypeKey = `${clientId}:${messageType}`;
|
||||
let messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
|
||||
|
||||
if (!messageTypeTracker || messageTypeTracker.windowStart < windowStart) {
|
||||
// New window or first request for this message type - initialize from Redis if available
|
||||
messageTypeTracker = await this.initializeMessageTypeTracker(clientId, messageType);
|
||||
messageTypeTracker.windowStart = currentTime;
|
||||
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
|
||||
}
|
||||
|
||||
// Increment message type counters
|
||||
messageTypeTracker.count++;
|
||||
messageTypeTracker.pendingCount++;
|
||||
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
|
||||
|
||||
// Check if message type limit would be exceeded
|
||||
if (messageTypeTracker.count >= messageTypeLimit) {
|
||||
return {
|
||||
isLimited: true,
|
||||
reason: `message_type:${messageType}`,
|
||||
totalHits: messageTypeTracker.count,
|
||||
resetTime: new Date((messageTypeTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
|
||||
};
|
||||
}
|
||||
|
||||
// Sync to Redis if threshold reached
|
||||
if (messageTypeTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
|
||||
this.syncMessageTypeRateLimitToRedis(clientId, messageType, messageTypeTracker);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isLimited: false,
|
||||
totalHits: globalTracker.count,
|
||||
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
|
||||
};
|
||||
}
|
||||
|
||||
// Decrement function for skipSuccessfulRequests/skipFailedRequests functionality
|
||||
async decrementRateLimit(clientId: string, messageType?: string): Promise<void> {
|
||||
// Decrement global counter
|
||||
const globalTracker = this.localRateLimitTracker.get(clientId);
|
||||
if (globalTracker && globalTracker.count > 0) {
|
||||
globalTracker.count--;
|
||||
// We need to account for this in pending count to sync correctly
|
||||
globalTracker.pendingCount--;
|
||||
}
|
||||
|
||||
// Decrement message type counter if provided
|
||||
if (messageType) {
|
||||
const messageTypeKey = `${clientId}:${messageType}`;
|
||||
const messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
|
||||
if (messageTypeTracker && messageTypeTracker.count > 0) {
|
||||
messageTypeTracker.count--;
|
||||
messageTypeTracker.pendingCount--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset key function
|
||||
async resetKey(clientId: string): Promise<void> {
|
||||
// Remove from local tracking
|
||||
this.localRateLimitTracker.delete(clientId);
|
||||
|
||||
// Remove all message type entries for this client
|
||||
for (const [key] of this.localMessageTypeRateLimitTracker) {
|
||||
if (key.startsWith(`${clientId}:`)) {
|
||||
this.localMessageTypeRateLimitTracker.delete(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from Redis if enabled
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
const globalKey = this.getRateLimitKey(clientId);
|
||||
await redisManager.del(globalKey);
|
||||
|
||||
// Get all message type keys for this client and delete them
|
||||
const client = redisManager.getClient();
|
||||
if (client) {
|
||||
const messageTypeKeys = await client.keys(`ratelimit:${clientId}:*`);
|
||||
if (messageTypeKeys.length > 0) {
|
||||
await Promise.all(messageTypeKeys.map(key => redisManager.del(key)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup old local rate limit entries and force sync pending data
|
||||
private async cleanupLocalRateLimit(): Promise<void> {
|
||||
const currentTime = Math.floor(Date.now() / 1000);
|
||||
const windowStart = currentTime - RATE_LIMIT_WINDOW;
|
||||
|
||||
// Clean up global rate limit tracking and sync pending data
|
||||
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
|
||||
if (tracker.windowStart < windowStart) {
|
||||
// Sync any pending data before cleanup
|
||||
if (tracker.pendingCount > 0) {
|
||||
await this.syncRateLimitToRedis(clientId, tracker);
|
||||
}
|
||||
this.localRateLimitTracker.delete(clientId);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up message type rate limit tracking and sync pending data
|
||||
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
|
||||
if (tracker.windowStart < windowStart) {
|
||||
// Sync any pending data before cleanup
|
||||
if (tracker.pendingCount > 0) {
|
||||
const [clientId, messageType] = key.split(":", 2);
|
||||
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
|
||||
}
|
||||
this.localMessageTypeRateLimitTracker.delete(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Force sync all pending rate limit data to Redis
|
||||
async forceSyncAllPendingData(): Promise<void> {
|
||||
if (!redisManager.isRedisEnabled()) return;
|
||||
|
||||
logger.debug("Force syncing all pending rate limit data to Redis...");
|
||||
|
||||
// Sync all pending global rate limits
|
||||
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
|
||||
if (tracker.pendingCount > 0) {
|
||||
await this.syncRateLimitToRedis(clientId, tracker);
|
||||
}
|
||||
}
|
||||
|
||||
// Sync all pending message type rate limits
|
||||
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
|
||||
if (tracker.pendingCount > 0) {
|
||||
const [clientId, messageType] = key.split(":", 2);
|
||||
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("Completed force sync of pending rate limit data");
|
||||
}
|
||||
|
||||
// Cleanup function for graceful shutdown
|
||||
async cleanup(): Promise<void> {
|
||||
if (build == "oss") {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clear intervals
|
||||
if (this.cleanupInterval) {
|
||||
clearInterval(this.cleanupInterval);
|
||||
}
|
||||
if (this.forceSyncInterval) {
|
||||
clearInterval(this.forceSyncInterval);
|
||||
}
|
||||
|
||||
// Force sync all pending data
|
||||
await this.forceSyncAllPendingData();
|
||||
|
||||
// Clear local data
|
||||
this.localRateLimitTracker.clear();
|
||||
this.localMessageTypeRateLimitTracker.clear();
|
||||
|
||||
logger.info("Rate limit service cleanup completed");
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const rateLimitService = new RateLimitService();
|
||||
32
server/private/lib/rateLimitStore.ts
Normal file
32
server/private/lib/rateLimitStore.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { build } from "@server/build";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import { MemoryStore, Store } from "express-rate-limit";
|
||||
import RedisStore from "#private/lib/redisStore";
|
||||
|
||||
export function createStore(): Store {
|
||||
if (build != "oss" && privateConfig.getRawPrivateConfig().flags?.enable_redis) {
|
||||
const rateLimitStore: Store = new RedisStore({
|
||||
prefix: "api-rate-limit", // Optional: customize Redis key prefix
|
||||
skipFailedRequests: true, // Don't count failed requests
|
||||
skipSuccessfulRequests: false // Count successful requests
|
||||
});
|
||||
|
||||
return rateLimitStore;
|
||||
} else {
|
||||
const rateLimitStore: Store = new MemoryStore();
|
||||
return rateLimitStore;
|
||||
}
|
||||
}
|
||||
191
server/private/lib/readConfigFile.ts
Normal file
191
server/private/lib/readConfigFile.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import fs from "fs";
|
||||
import yaml from "js-yaml";
|
||||
import { privateConfigFilePath1 } from "@server/lib/consts";
|
||||
import { z } from "zod";
|
||||
import { colorsSchema } from "@server/lib/colorsSchema";
|
||||
import { build } from "@server/build";
|
||||
|
||||
const portSchema = z.number().positive().gt(0).lte(65535);
|
||||
|
||||
export const privateConfigSchema = z
|
||||
.object({
|
||||
app: z.object({
|
||||
region: z.string().optional().default("default"),
|
||||
base_domain: z.string().optional()
|
||||
}).optional().default({
|
||||
region: "default"
|
||||
}),
|
||||
server: z.object({
|
||||
encryption_key_path: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("./config/encryption.pem")
|
||||
.pipe(z.string().min(8)),
|
||||
resend_api_key: z.string().optional(),
|
||||
reo_client_id: z.string().optional(),
|
||||
}).optional().default({
|
||||
encryption_key_path: "./config/encryption.pem"
|
||||
}),
|
||||
redis: z
|
||||
.object({
|
||||
host: z.string(),
|
||||
port: portSchema,
|
||||
password: z.string().optional(),
|
||||
db: z.number().int().nonnegative().optional().default(0),
|
||||
replicas: z
|
||||
.array(
|
||||
z.object({
|
||||
host: z.string(),
|
||||
port: portSchema,
|
||||
password: z.string().optional(),
|
||||
db: z.number().int().nonnegative().optional().default(0)
|
||||
})
|
||||
)
|
||||
.optional()
|
||||
// tls: z
|
||||
// .object({
|
||||
// reject_unauthorized: z
|
||||
// .boolean()
|
||||
// .optional()
|
||||
// .default(true)
|
||||
// })
|
||||
// .optional()
|
||||
})
|
||||
.optional(),
|
||||
gerbil: z
|
||||
.object({
|
||||
local_exit_node_reachable_at: z.string().optional().default("http://gerbil:3003")
|
||||
})
|
||||
.optional()
|
||||
.default({}),
|
||||
flags: z
|
||||
.object({
|
||||
enable_redis: z.boolean().optional(),
|
||||
})
|
||||
.optional(),
|
||||
branding: z
|
||||
.object({
|
||||
app_name: z.string().optional(),
|
||||
background_image_path: z.string().optional(),
|
||||
colors: z
|
||||
.object({
|
||||
light: colorsSchema.optional(),
|
||||
dark: colorsSchema.optional()
|
||||
})
|
||||
.optional(),
|
||||
logo: z
|
||||
.object({
|
||||
light_path: z.string().optional(),
|
||||
dark_path: z.string().optional(),
|
||||
auth_page: z
|
||||
.object({
|
||||
width: z.number().optional(),
|
||||
height: z.number().optional()
|
||||
})
|
||||
.optional(),
|
||||
navbar: z
|
||||
.object({
|
||||
width: z.number().optional(),
|
||||
height: z.number().optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
.optional(),
|
||||
favicon_path: z.string().optional(),
|
||||
footer: z
|
||||
.array(
|
||||
z.object({
|
||||
text: z.string(),
|
||||
href: z.string().optional()
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
login_page: z
|
||||
.object({
|
||||
subtitle_text: z.string().optional(),
|
||||
title_text: z.string().optional()
|
||||
})
|
||||
.optional(),
|
||||
signup_page: z
|
||||
.object({
|
||||
subtitle_text: z.string().optional(),
|
||||
title_text: z.string().optional()
|
||||
})
|
||||
.optional(),
|
||||
resource_auth_page: z
|
||||
.object({
|
||||
show_logo: z.boolean().optional(),
|
||||
hide_powered_by: z.boolean().optional(),
|
||||
title_text: z.string().optional(),
|
||||
subtitle_text: z.string().optional()
|
||||
})
|
||||
.optional(),
|
||||
emails: z
|
||||
.object({
|
||||
signature: z.string().optional(),
|
||||
colors: z
|
||||
.object({
|
||||
primary: z.string().optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
.optional(),
|
||||
stripe: z
|
||||
.object({
|
||||
secret_key: z.string(),
|
||||
webhook_secret: z.string(),
|
||||
s3Bucket: z.string(),
|
||||
s3Region: z.string().default("us-east-1"),
|
||||
localFilePath: z.string()
|
||||
})
|
||||
.optional(),
|
||||
});
|
||||
|
||||
export function readPrivateConfigFile() {
|
||||
if (build == "oss") {
|
||||
return {};
|
||||
}
|
||||
|
||||
const loadConfig = (configPath: string) => {
|
||||
try {
|
||||
const yamlContent = fs.readFileSync(configPath, "utf8");
|
||||
const config = yaml.load(yamlContent);
|
||||
return config;
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new Error(
|
||||
`Error loading configuration file: ${error.message}`
|
||||
);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
let environment: any;
|
||||
if (fs.existsSync(privateConfigFilePath1)) {
|
||||
environment = loadConfig(privateConfigFilePath1);
|
||||
}
|
||||
|
||||
if (!environment) {
|
||||
throw new Error(
|
||||
"No private configuration file found."
|
||||
);
|
||||
}
|
||||
|
||||
return environment;
|
||||
}
|
||||
782
server/private/lib/redis.ts
Normal file
782
server/private/lib/redis.ts
Normal file
@@ -0,0 +1,782 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Redis, { RedisOptions } from "ioredis";
|
||||
import logger from "@server/logger";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import { build } from "@server/build";
|
||||
|
||||
class RedisManager {
|
||||
public client: Redis | null = null;
|
||||
private writeClient: Redis | null = null; // Master for writes
|
||||
private readClient: Redis | null = null; // Replica for reads
|
||||
private subscriber: Redis | null = null;
|
||||
private publisher: Redis | null = null;
|
||||
private isEnabled: boolean = false;
|
||||
private isHealthy: boolean = true;
|
||||
private isWriteHealthy: boolean = true;
|
||||
private isReadHealthy: boolean = true;
|
||||
private lastHealthCheck: number = 0;
|
||||
private healthCheckInterval: number = 30000; // 30 seconds
|
||||
private connectionTimeout: number = 15000; // 15 seconds
|
||||
private commandTimeout: number = 15000; // 15 seconds
|
||||
private hasReplicas: boolean = false;
|
||||
private maxRetries: number = 3;
|
||||
private baseRetryDelay: number = 100; // 100ms
|
||||
private maxRetryDelay: number = 2000; // 2 seconds
|
||||
private backoffMultiplier: number = 2;
|
||||
private subscribers: Map<
|
||||
string,
|
||||
Set<(channel: string, message: string) => void>
|
||||
> = new Map();
|
||||
private reconnectionCallbacks: Set<() => Promise<void>> = new Set();
|
||||
|
||||
constructor() {
|
||||
if (build == "oss") {
|
||||
this.isEnabled = false;
|
||||
return;
|
||||
}
|
||||
this.isEnabled = privateConfig.getRawPrivateConfig().flags?.enable_redis || false;
|
||||
if (this.isEnabled) {
|
||||
this.initializeClients();
|
||||
}
|
||||
}
|
||||
|
||||
// Register callback to be called when Redis reconnects
|
||||
public onReconnection(callback: () => Promise<void>): void {
|
||||
this.reconnectionCallbacks.add(callback);
|
||||
}
|
||||
|
||||
// Unregister reconnection callback
|
||||
public offReconnection(callback: () => Promise<void>): void {
|
||||
this.reconnectionCallbacks.delete(callback);
|
||||
}
|
||||
|
||||
private async triggerReconnectionCallbacks(): Promise<void> {
|
||||
logger.info(`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`);
|
||||
|
||||
const promises = Array.from(this.reconnectionCallbacks).map(async (callback) => {
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
logger.error("Error in reconnection callback:", error);
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.allSettled(promises);
|
||||
}
|
||||
|
||||
private async resubscribeToChannels(): Promise<void> {
|
||||
if (!this.subscriber || this.subscribers.size === 0) return;
|
||||
|
||||
logger.info(`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`);
|
||||
|
||||
try {
|
||||
const channels = Array.from(this.subscribers.keys());
|
||||
if (channels.length > 0) {
|
||||
await this.subscriber.subscribe(...channels);
|
||||
logger.info(`Successfully re-subscribed to channels: ${channels.join(', ')}`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Failed to re-subscribe to channels:", error);
|
||||
}
|
||||
}
|
||||
|
||||
private getRedisConfig(): RedisOptions {
|
||||
const redisConfig = privateConfig.getRawPrivateConfig().redis!;
|
||||
const opts: RedisOptions = {
|
||||
host: redisConfig.host!,
|
||||
port: redisConfig.port!,
|
||||
password: redisConfig.password,
|
||||
db: redisConfig.db,
|
||||
// tls: {
|
||||
// rejectUnauthorized:
|
||||
// redisConfig.tls?.reject_unauthorized || false
|
||||
// }
|
||||
};
|
||||
return opts;
|
||||
}
|
||||
|
||||
private getReplicaRedisConfig(): RedisOptions | null {
|
||||
const redisConfig = privateConfig.getRawPrivateConfig().redis!;
|
||||
if (!redisConfig.replicas || redisConfig.replicas.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Use the first replica for simplicity
|
||||
// In production, you might want to implement load balancing across replicas
|
||||
const replica = redisConfig.replicas[0];
|
||||
const opts: RedisOptions = {
|
||||
host: replica.host!,
|
||||
port: replica.port!,
|
||||
password: replica.password,
|
||||
db: replica.db || redisConfig.db,
|
||||
// tls: {
|
||||
// rejectUnauthorized:
|
||||
// replica.tls?.reject_unauthorized || false
|
||||
// }
|
||||
};
|
||||
return opts;
|
||||
}
|
||||
|
||||
// Add reconnection logic in initializeClients
|
||||
private initializeClients(): void {
|
||||
const masterConfig = this.getRedisConfig();
|
||||
const replicaConfig = this.getReplicaRedisConfig();
|
||||
|
||||
this.hasReplicas = replicaConfig !== null;
|
||||
|
||||
try {
|
||||
// Initialize master connection for writes
|
||||
this.writeClient = new Redis({
|
||||
...masterConfig,
|
||||
enableReadyCheck: false,
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
});
|
||||
|
||||
// Initialize replica connection for reads (if available)
|
||||
if (this.hasReplicas) {
|
||||
this.readClient = new Redis({
|
||||
...replicaConfig!,
|
||||
enableReadyCheck: false,
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
});
|
||||
} else {
|
||||
// Fallback to master for reads if no replicas
|
||||
this.readClient = this.writeClient;
|
||||
}
|
||||
|
||||
// Backward compatibility - point to write client
|
||||
this.client = this.writeClient;
|
||||
|
||||
// Publisher uses master (writes)
|
||||
this.publisher = new Redis({
|
||||
...masterConfig,
|
||||
enableReadyCheck: false,
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
});
|
||||
|
||||
// Subscriber uses replica if available (reads)
|
||||
this.subscriber = new Redis({
|
||||
...(this.hasReplicas ? replicaConfig! : masterConfig),
|
||||
enableReadyCheck: false,
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
});
|
||||
|
||||
// Add reconnection handlers for write client
|
||||
this.writeClient.on("error", (err) => {
|
||||
logger.error("Redis write client error:", err);
|
||||
this.isWriteHealthy = false;
|
||||
this.isHealthy = false;
|
||||
});
|
||||
|
||||
this.writeClient.on("reconnecting", () => {
|
||||
logger.info("Redis write client reconnecting...");
|
||||
this.isWriteHealthy = false;
|
||||
this.isHealthy = false;
|
||||
});
|
||||
|
||||
this.writeClient.on("ready", () => {
|
||||
logger.info("Redis write client ready");
|
||||
this.isWriteHealthy = true;
|
||||
this.updateOverallHealth();
|
||||
|
||||
// Trigger reconnection callbacks when Redis comes back online
|
||||
if (this.isHealthy) {
|
||||
this.triggerReconnectionCallbacks().catch(error => {
|
||||
logger.error("Error triggering reconnection callbacks:", error);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
this.writeClient.on("connect", () => {
|
||||
logger.info("Redis write client connected");
|
||||
});
|
||||
|
||||
// Add reconnection handlers for read client (if different from write)
|
||||
if (this.hasReplicas && this.readClient !== this.writeClient) {
|
||||
this.readClient.on("error", (err) => {
|
||||
logger.error("Redis read client error:", err);
|
||||
this.isReadHealthy = false;
|
||||
this.updateOverallHealth();
|
||||
});
|
||||
|
||||
this.readClient.on("reconnecting", () => {
|
||||
logger.info("Redis read client reconnecting...");
|
||||
this.isReadHealthy = false;
|
||||
this.updateOverallHealth();
|
||||
});
|
||||
|
||||
this.readClient.on("ready", () => {
|
||||
logger.info("Redis read client ready");
|
||||
this.isReadHealthy = true;
|
||||
this.updateOverallHealth();
|
||||
|
||||
// Trigger reconnection callbacks when Redis comes back online
|
||||
if (this.isHealthy) {
|
||||
this.triggerReconnectionCallbacks().catch(error => {
|
||||
logger.error("Error triggering reconnection callbacks:", error);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
this.readClient.on("connect", () => {
|
||||
logger.info("Redis read client connected");
|
||||
});
|
||||
} else {
|
||||
// If using same client for reads and writes
|
||||
this.isReadHealthy = this.isWriteHealthy;
|
||||
}
|
||||
|
||||
this.publisher.on("error", (err) => {
|
||||
logger.error("Redis publisher error:", err);
|
||||
});
|
||||
|
||||
this.publisher.on("ready", () => {
|
||||
logger.info("Redis publisher ready");
|
||||
});
|
||||
|
||||
this.publisher.on("connect", () => {
|
||||
logger.info("Redis publisher connected");
|
||||
});
|
||||
|
||||
this.subscriber.on("error", (err) => {
|
||||
logger.error("Redis subscriber error:", err);
|
||||
});
|
||||
|
||||
this.subscriber.on("ready", () => {
|
||||
logger.info("Redis subscriber ready");
|
||||
// Re-subscribe to all channels after reconnection
|
||||
this.resubscribeToChannels().catch((error: any) => {
|
||||
logger.error("Error re-subscribing to channels:", error);
|
||||
});
|
||||
});
|
||||
|
||||
this.subscriber.on("connect", () => {
|
||||
logger.info("Redis subscriber connected");
|
||||
});
|
||||
|
||||
// Set up message handler for subscriber
|
||||
this.subscriber.on(
|
||||
"message",
|
||||
(channel: string, message: string) => {
|
||||
const channelSubscribers = this.subscribers.get(channel);
|
||||
if (channelSubscribers) {
|
||||
channelSubscribers.forEach((callback) => {
|
||||
try {
|
||||
callback(channel, message);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error in subscriber callback for channel ${channel}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
const setupMessage = this.hasReplicas
|
||||
? "Redis clients initialized successfully with replica support"
|
||||
: "Redis clients initialized successfully (single instance)";
|
||||
logger.info(setupMessage);
|
||||
|
||||
// Start periodic health monitoring
|
||||
this.startHealthMonitoring();
|
||||
} catch (error) {
|
||||
logger.error("Failed to initialize Redis clients:", error);
|
||||
this.isEnabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
private updateOverallHealth(): void {
|
||||
// Overall health is true if write is healthy and (read is healthy OR we don't have replicas)
|
||||
this.isHealthy = this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
|
||||
}
|
||||
|
||||
private async executeWithRetry<T>(
|
||||
operation: () => Promise<T>,
|
||||
operationName: string,
|
||||
fallbackOperation?: () => Promise<T>
|
||||
): Promise<T> {
|
||||
let lastError: Error | null = null;
|
||||
|
||||
for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
|
||||
try {
|
||||
return await operation();
|
||||
} catch (error) {
|
||||
lastError = error as Error;
|
||||
|
||||
// If this is the last attempt, try fallback if available
|
||||
if (attempt === this.maxRetries && fallbackOperation) {
|
||||
try {
|
||||
logger.warn(`${operationName} primary operation failed, trying fallback`);
|
||||
return await fallbackOperation();
|
||||
} catch (fallbackError) {
|
||||
logger.error(`${operationName} fallback also failed:`, fallbackError);
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
|
||||
// Don't retry on the last attempt
|
||||
if (attempt === this.maxRetries) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Calculate delay with exponential backoff
|
||||
const delay = Math.min(
|
||||
this.baseRetryDelay * Math.pow(this.backoffMultiplier, attempt),
|
||||
this.maxRetryDelay
|
||||
);
|
||||
|
||||
logger.warn(`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`, error);
|
||||
|
||||
// Wait before retrying
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
}
|
||||
}
|
||||
|
||||
logger.error(`${operationName} failed after ${this.maxRetries + 1} attempts:`, lastError);
|
||||
throw lastError;
|
||||
}
|
||||
|
||||
private startHealthMonitoring(): void {
|
||||
if (!this.isEnabled) return;
|
||||
|
||||
// Check health every 30 seconds
|
||||
setInterval(async () => {
|
||||
try {
|
||||
await this.checkRedisHealth();
|
||||
} catch (error) {
|
||||
logger.error("Error during Redis health monitoring:", error);
|
||||
}
|
||||
}, this.healthCheckInterval);
|
||||
}
|
||||
|
||||
public isRedisEnabled(): boolean {
|
||||
return this.isEnabled && this.client !== null && this.isHealthy;
|
||||
}
|
||||
|
||||
private async checkRedisHealth(): Promise<boolean> {
|
||||
const now = Date.now();
|
||||
|
||||
// Only check health every 30 seconds
|
||||
if (now - this.lastHealthCheck < this.healthCheckInterval) {
|
||||
return this.isHealthy;
|
||||
}
|
||||
|
||||
this.lastHealthCheck = now;
|
||||
|
||||
if (!this.writeClient) {
|
||||
this.isHealthy = false;
|
||||
this.isWriteHealthy = false;
|
||||
this.isReadHealthy = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
// Check write client (master) health
|
||||
await Promise.race([
|
||||
this.writeClient.ping(),
|
||||
new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Write client health check timeout')), 2000)
|
||||
)
|
||||
]);
|
||||
this.isWriteHealthy = true;
|
||||
|
||||
// Check read client health if it's different from write client
|
||||
if (this.hasReplicas && this.readClient && this.readClient !== this.writeClient) {
|
||||
try {
|
||||
await Promise.race([
|
||||
this.readClient.ping(),
|
||||
new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Read client health check timeout')), 2000)
|
||||
)
|
||||
]);
|
||||
this.isReadHealthy = true;
|
||||
} catch (error) {
|
||||
logger.error("Redis read client health check failed:", error);
|
||||
this.isReadHealthy = false;
|
||||
}
|
||||
} else {
|
||||
this.isReadHealthy = this.isWriteHealthy;
|
||||
}
|
||||
|
||||
this.updateOverallHealth();
|
||||
return this.isHealthy;
|
||||
} catch (error) {
|
||||
logger.error("Redis write client health check failed:", error);
|
||||
this.isWriteHealthy = false;
|
||||
this.isReadHealthy = false; // If write fails, consider read as failed too for safety
|
||||
this.isHealthy = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public getClient(): Redis {
|
||||
return this.client!;
|
||||
}
|
||||
|
||||
public getWriteClient(): Redis | null {
|
||||
return this.writeClient;
|
||||
}
|
||||
|
||||
public getReadClient(): Redis | null {
|
||||
return this.readClient;
|
||||
}
|
||||
|
||||
public hasReplicaSupport(): boolean {
|
||||
return this.hasReplicas;
|
||||
}
|
||||
|
||||
public getHealthStatus(): {
|
||||
isEnabled: boolean;
|
||||
isHealthy: boolean;
|
||||
isWriteHealthy: boolean;
|
||||
isReadHealthy: boolean;
|
||||
hasReplicas: boolean;
|
||||
} {
|
||||
return {
|
||||
isEnabled: this.isEnabled,
|
||||
isHealthy: this.isHealthy,
|
||||
isWriteHealthy: this.isWriteHealthy,
|
||||
isReadHealthy: this.isReadHealthy,
|
||||
hasReplicas: this.hasReplicas
|
||||
};
|
||||
}
|
||||
|
||||
public async set(
|
||||
key: string,
|
||||
value: string,
|
||||
ttl?: number
|
||||
): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
async () => {
|
||||
if (ttl) {
|
||||
await this.writeClient!.setex(key, ttl, value);
|
||||
} else {
|
||||
await this.writeClient!.set(key, value);
|
||||
}
|
||||
},
|
||||
"Redis SET"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis SET error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async get(key: string): Promise<string | null> {
|
||||
if (!this.isRedisEnabled() || !this.readClient) return null;
|
||||
|
||||
try {
|
||||
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
|
||||
? () => this.writeClient!.get(key)
|
||||
: undefined;
|
||||
|
||||
return await this.executeWithRetry(
|
||||
() => this.readClient!.get(key),
|
||||
"Redis GET",
|
||||
fallbackOperation
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Redis GET error:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public async del(key: string): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
() => this.writeClient!.del(key),
|
||||
"Redis DEL"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis DEL error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async sadd(key: string, member: string): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
() => this.writeClient!.sadd(key, member),
|
||||
"Redis SADD"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis SADD error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async srem(key: string, member: string): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
() => this.writeClient!.srem(key, member),
|
||||
"Redis SREM"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis SREM error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async smembers(key: string): Promise<string[]> {
|
||||
if (!this.isRedisEnabled() || !this.readClient) return [];
|
||||
|
||||
try {
|
||||
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
|
||||
? () => this.writeClient!.smembers(key)
|
||||
: undefined;
|
||||
|
||||
return await this.executeWithRetry(
|
||||
() => this.readClient!.smembers(key),
|
||||
"Redis SMEMBERS",
|
||||
fallbackOperation
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Redis SMEMBERS error:", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
public async hset(
|
||||
key: string,
|
||||
field: string,
|
||||
value: string
|
||||
): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
() => this.writeClient!.hset(key, field, value),
|
||||
"Redis HSET"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis HSET error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async hget(key: string, field: string): Promise<string | null> {
|
||||
if (!this.isRedisEnabled() || !this.readClient) return null;
|
||||
|
||||
try {
|
||||
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
|
||||
? () => this.writeClient!.hget(key, field)
|
||||
: undefined;
|
||||
|
||||
return await this.executeWithRetry(
|
||||
() => this.readClient!.hget(key, field),
|
||||
"Redis HGET",
|
||||
fallbackOperation
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Redis HGET error:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public async hdel(key: string, field: string): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.writeClient) return false;
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
() => this.writeClient!.hdel(key, field),
|
||||
"Redis HDEL"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis HDEL error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async hgetall(key: string): Promise<Record<string, string>> {
|
||||
if (!this.isRedisEnabled() || !this.readClient) return {};
|
||||
|
||||
try {
|
||||
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
|
||||
? () => this.writeClient!.hgetall(key)
|
||||
: undefined;
|
||||
|
||||
return await this.executeWithRetry(
|
||||
() => this.readClient!.hgetall(key),
|
||||
"Redis HGETALL",
|
||||
fallbackOperation
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Redis HGETALL error:", error);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
public async publish(channel: string, message: string): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.publisher) return false;
|
||||
|
||||
// Quick health check before attempting to publish
|
||||
const isHealthy = await this.checkRedisHealth();
|
||||
if (!isHealthy) {
|
||||
logger.warn("Skipping Redis publish due to unhealthy connection");
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
await this.executeWithRetry(
|
||||
async () => {
|
||||
// Add timeout to prevent hanging
|
||||
return Promise.race([
|
||||
this.publisher!.publish(channel, message),
|
||||
new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
|
||||
)
|
||||
]);
|
||||
},
|
||||
"Redis PUBLISH"
|
||||
);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis PUBLISH error:", error);
|
||||
this.isHealthy = false; // Mark as unhealthy on error
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async subscribe(
|
||||
channel: string,
|
||||
callback: (channel: string, message: string) => void
|
||||
): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.subscriber) return false;
|
||||
|
||||
try {
|
||||
// Add callback to subscribers map
|
||||
if (!this.subscribers.has(channel)) {
|
||||
this.subscribers.set(channel, new Set());
|
||||
// Only subscribe to the channel if it's the first subscriber
|
||||
await this.executeWithRetry(
|
||||
async () => {
|
||||
return Promise.race([
|
||||
this.subscriber!.subscribe(channel),
|
||||
new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
|
||||
)
|
||||
]);
|
||||
},
|
||||
"Redis SUBSCRIBE"
|
||||
);
|
||||
}
|
||||
|
||||
this.subscribers.get(channel)!.add(callback);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis SUBSCRIBE error:", error);
|
||||
this.isHealthy = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async unsubscribe(
|
||||
channel: string,
|
||||
callback?: (channel: string, message: string) => void
|
||||
): Promise<boolean> {
|
||||
if (!this.isRedisEnabled() || !this.subscriber) return false;
|
||||
|
||||
try {
|
||||
const channelSubscribers = this.subscribers.get(channel);
|
||||
if (!channelSubscribers) return true;
|
||||
|
||||
if (callback) {
|
||||
// Remove specific callback
|
||||
channelSubscribers.delete(callback);
|
||||
if (channelSubscribers.size === 0) {
|
||||
this.subscribers.delete(channel);
|
||||
await this.executeWithRetry(
|
||||
() => this.subscriber!.unsubscribe(channel),
|
||||
"Redis UNSUBSCRIBE"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Remove all callbacks for this channel
|
||||
this.subscribers.delete(channel);
|
||||
await this.executeWithRetry(
|
||||
() => this.subscriber!.unsubscribe(channel),
|
||||
"Redis UNSUBSCRIBE"
|
||||
);
|
||||
}
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error("Redis UNSUBSCRIBE error:", error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public async disconnect(): Promise<void> {
|
||||
try {
|
||||
if (this.client) {
|
||||
await this.client.quit();
|
||||
this.client = null;
|
||||
}
|
||||
if (this.writeClient) {
|
||||
await this.writeClient.quit();
|
||||
this.writeClient = null;
|
||||
}
|
||||
if (this.readClient && this.readClient !== this.writeClient) {
|
||||
await this.readClient.quit();
|
||||
this.readClient = null;
|
||||
}
|
||||
if (this.publisher) {
|
||||
await this.publisher.quit();
|
||||
this.publisher = null;
|
||||
}
|
||||
if (this.subscriber) {
|
||||
await this.subscriber.quit();
|
||||
this.subscriber = null;
|
||||
}
|
||||
this.subscribers.clear();
|
||||
logger.info("Redis clients disconnected");
|
||||
} catch (error) {
|
||||
logger.error("Error disconnecting Redis clients:", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const redisManager = new RedisManager();
|
||||
export const redis = redisManager.getClient();
|
||||
export default redisManager;
|
||||
223
server/private/lib/redisStore.ts
Normal file
223
server/private/lib/redisStore.ts
Normal file
@@ -0,0 +1,223 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Store, Options, IncrementResponse } from 'express-rate-limit';
|
||||
import { rateLimitService } from './rateLimit';
|
||||
import logger from '@server/logger';
|
||||
|
||||
/**
|
||||
* A Redis-backed rate limiting store for express-rate-limit that optimizes
|
||||
* for local read performance and batched writes to Redis.
|
||||
*
|
||||
* This store uses the same optimized rate limiting logic as the WebSocket
|
||||
* implementation, providing:
|
||||
* - Local caching for fast reads
|
||||
* - Batched writes to Redis to reduce load
|
||||
* - Automatic cleanup of expired entries
|
||||
* - Graceful fallback when Redis is unavailable
|
||||
*/
|
||||
export default class RedisStore implements Store {
|
||||
/**
|
||||
* The duration of time before which all hit counts are reset (in milliseconds).
|
||||
*/
|
||||
windowMs!: number;
|
||||
|
||||
/**
|
||||
* Maximum number of requests allowed within the window.
|
||||
*/
|
||||
max!: number;
|
||||
|
||||
/**
|
||||
* Optional prefix for Redis keys to avoid collisions.
|
||||
*/
|
||||
prefix: string;
|
||||
|
||||
/**
|
||||
* Whether to skip incrementing on failed requests.
|
||||
*/
|
||||
skipFailedRequests: boolean;
|
||||
|
||||
/**
|
||||
* Whether to skip incrementing on successful requests.
|
||||
*/
|
||||
skipSuccessfulRequests: boolean;
|
||||
|
||||
/**
|
||||
* @constructor for RedisStore.
|
||||
*
|
||||
* @param options - Configuration options for the store.
|
||||
*/
|
||||
constructor(options: {
|
||||
prefix?: string;
|
||||
skipFailedRequests?: boolean;
|
||||
skipSuccessfulRequests?: boolean;
|
||||
} = {}) {
|
||||
this.prefix = options.prefix || 'express-rate-limit';
|
||||
this.skipFailedRequests = options.skipFailedRequests || false;
|
||||
this.skipSuccessfulRequests = options.skipSuccessfulRequests || false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Method that actually initializes the store. Must be synchronous.
|
||||
*
|
||||
* @param options - The options used to setup express-rate-limit.
|
||||
*/
|
||||
init(options: Options): void {
|
||||
this.windowMs = options.windowMs;
|
||||
this.max = options.max as number;
|
||||
|
||||
// logger.debug(`RedisStore initialized with windowMs: ${this.windowMs}, max: ${this.max}, prefix: ${this.prefix}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to increment a client's hit counter.
|
||||
*
|
||||
* @param key - The identifier for a client (usually IP address).
|
||||
* @returns Promise resolving to the number of hits and reset time for that client.
|
||||
*/
|
||||
async increment(key: string): Promise<IncrementResponse> {
|
||||
try {
|
||||
const clientId = `${this.prefix}:${key}`;
|
||||
|
||||
const result = await rateLimitService.checkRateLimit(
|
||||
clientId,
|
||||
undefined, // No message type for HTTP requests
|
||||
this.max,
|
||||
undefined, // No message type limit
|
||||
this.windowMs
|
||||
);
|
||||
|
||||
// logger.debug(`Incremented rate limit for key: ${key} with max: ${this.max}, totalHits: ${result.totalHits}`);
|
||||
|
||||
return {
|
||||
totalHits: result.totalHits || 1,
|
||||
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`RedisStore increment error for key ${key}:`, error);
|
||||
|
||||
// Return safe defaults on error to prevent blocking requests
|
||||
return {
|
||||
totalHits: 1,
|
||||
resetTime: new Date(Date.now() + this.windowMs)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to decrement a client's hit counter.
|
||||
* Used when skipSuccessfulRequests or skipFailedRequests is enabled.
|
||||
*
|
||||
* @param key - The identifier for a client.
|
||||
*/
|
||||
async decrement(key: string): Promise<void> {
|
||||
try {
|
||||
const clientId = `${this.prefix}:${key}`;
|
||||
await rateLimitService.decrementRateLimit(clientId);
|
||||
|
||||
// logger.debug(`Decremented rate limit for key: ${key}`);
|
||||
} catch (error) {
|
||||
logger.error(`RedisStore decrement error for key ${key}:`, error);
|
||||
// Don't throw - decrement failures shouldn't block requests
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to reset a client's hit counter.
|
||||
*
|
||||
* @param key - The identifier for a client.
|
||||
*/
|
||||
async resetKey(key: string): Promise<void> {
|
||||
try {
|
||||
const clientId = `${this.prefix}:${key}`;
|
||||
await rateLimitService.resetKey(clientId);
|
||||
|
||||
// logger.debug(`Reset rate limit for key: ${key}`);
|
||||
} catch (error) {
|
||||
logger.error(`RedisStore resetKey error for key ${key}:`, error);
|
||||
// Don't throw - reset failures shouldn't block requests
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Method to reset everyone's hit counter.
|
||||
*
|
||||
* This method is optional and is never called by express-rate-limit.
|
||||
* We implement it for completeness but it's not recommended for production use
|
||||
* as it could be expensive with large datasets.
|
||||
*/
|
||||
async resetAll(): Promise<void> {
|
||||
try {
|
||||
logger.warn('RedisStore resetAll called - this operation can be expensive');
|
||||
|
||||
// Force sync all pending data first
|
||||
await rateLimitService.forceSyncAllPendingData();
|
||||
|
||||
// Note: We don't actually implement full reset as it would require
|
||||
// scanning all Redis keys with our prefix, which could be expensive.
|
||||
// In production, it's better to let entries expire naturally.
|
||||
|
||||
logger.info('RedisStore resetAll completed (pending data synced)');
|
||||
} catch (error) {
|
||||
logger.error('RedisStore resetAll error:', error);
|
||||
// Don't throw - this is an optional method
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current hit count for a key without incrementing.
|
||||
* This is a custom method not part of the Store interface.
|
||||
*
|
||||
* @param key - The identifier for a client.
|
||||
* @returns Current hit count and reset time, or null if no data exists.
|
||||
*/
|
||||
async getHits(key: string): Promise<{ totalHits: number; resetTime: Date } | null> {
|
||||
try {
|
||||
const clientId = `${this.prefix}:${key}`;
|
||||
|
||||
// Use checkRateLimit with max + 1 to avoid actually incrementing
|
||||
// but still get the current count
|
||||
const result = await rateLimitService.checkRateLimit(
|
||||
clientId,
|
||||
undefined,
|
||||
this.max + 1000, // Set artificially high to avoid triggering limit
|
||||
undefined,
|
||||
this.windowMs
|
||||
);
|
||||
|
||||
// Decrement since we don't actually want to count this check
|
||||
await rateLimitService.decrementRateLimit(clientId);
|
||||
|
||||
return {
|
||||
totalHits: Math.max(0, (result.totalHits || 0) - 1), // Adjust for the decrement
|
||||
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`RedisStore getHits error for key ${key}:`, error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup method for graceful shutdown.
|
||||
* This is not part of the Store interface but is useful for cleanup.
|
||||
*/
|
||||
async shutdown(): Promise<void> {
|
||||
try {
|
||||
// The rateLimitService handles its own cleanup
|
||||
logger.info('RedisStore shutdown completed');
|
||||
} catch (error) {
|
||||
logger.error('RedisStore shutdown error:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
124
server/private/lib/resend.ts
Normal file
124
server/private/lib/resend.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Resend } from "resend";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export enum AudienceIds {
|
||||
General = "5cfbf99b-c592-40a9-9b8a-577a4681c158",
|
||||
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
|
||||
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549"
|
||||
}
|
||||
|
||||
const resend = new Resend(
|
||||
privateConfig.getRawPrivateConfig().server.resend_api_key || "missing"
|
||||
);
|
||||
|
||||
export default resend;
|
||||
|
||||
export async function moveEmailToAudience(
|
||||
email: string,
|
||||
audienceId: AudienceIds
|
||||
) {
|
||||
if (process.env.ENVIRONMENT !== "prod") {
|
||||
logger.debug(`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`);
|
||||
return;
|
||||
}
|
||||
const { error, data } = await retryWithBackoff(async () => {
|
||||
const { data, error } = await resend.contacts.create({
|
||||
email,
|
||||
unsubscribed: false,
|
||||
audienceId
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(
|
||||
`Error adding email ${email} to audience ${audienceId}: ${error}`
|
||||
);
|
||||
}
|
||||
return { error, data };
|
||||
});
|
||||
|
||||
if (error) {
|
||||
logger.error(
|
||||
`Error adding email ${email} to audience ${audienceId}: ${error}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (data) {
|
||||
logger.debug(
|
||||
`Added email ${email} to audience ${audienceId} with contact ID ${data.id}`
|
||||
);
|
||||
}
|
||||
|
||||
const otherAudiences = Object.values(AudienceIds).filter(
|
||||
(id) => id !== audienceId
|
||||
);
|
||||
|
||||
for (const otherAudienceId of otherAudiences) {
|
||||
const { error, data } = await retryWithBackoff(async () => {
|
||||
const { data, error } = await resend.contacts.remove({
|
||||
email,
|
||||
audienceId: otherAudienceId
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(
|
||||
`Error removing email ${email} from audience ${otherAudienceId}: ${error}`
|
||||
);
|
||||
}
|
||||
return { error, data };
|
||||
});
|
||||
|
||||
if (error) {
|
||||
logger.error(
|
||||
`Error removing email ${email} from audience ${otherAudienceId}: ${error}`
|
||||
);
|
||||
}
|
||||
|
||||
if (data) {
|
||||
logger.info(
|
||||
`Removed email ${email} from audience ${otherAudienceId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type RetryOptions = {
|
||||
retries?: number;
|
||||
initialDelayMs?: number;
|
||||
factor?: number;
|
||||
};
|
||||
|
||||
export async function retryWithBackoff<T>(
|
||||
fn: () => Promise<T>,
|
||||
options: RetryOptions = {}
|
||||
): Promise<T> {
|
||||
const { retries = 5, initialDelayMs = 500, factor = 2 } = options;
|
||||
|
||||
let attempt = 0;
|
||||
let delay = initialDelayMs;
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
return await fn();
|
||||
} catch (err) {
|
||||
attempt++;
|
||||
|
||||
if (attempt > retries) throw err;
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
delay *= factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
28
server/private/lib/stripe.ts
Normal file
28
server/private/lib/stripe.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import logger from "@server/logger";
|
||||
import { build } from "@server/build";
|
||||
|
||||
let stripe: Stripe | undefined = undefined;
|
||||
if (build == "saas") {
|
||||
const stripeApiKey = privateConfig.getRawPrivateConfig().stripe?.secret_key;
|
||||
if (!stripeApiKey) {
|
||||
logger.error("Stripe secret key is not configured");
|
||||
}
|
||||
stripe = new Stripe(stripeApiKey!);
|
||||
}
|
||||
|
||||
export default stripe;
|
||||
712
server/private/lib/traefik/getTraefikConfig.ts
Normal file
712
server/private/lib/traefik/getTraefikConfig.ts
Normal file
@@ -0,0 +1,712 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import {
|
||||
certificates,
|
||||
db,
|
||||
domainNamespaces,
|
||||
exitNodes,
|
||||
loginPage,
|
||||
targetHealthCheck
|
||||
} from "@server/db";
|
||||
import { and, eq, inArray, or, isNull, ne, isNotNull, desc } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import config from "@server/lib/config";
|
||||
import { orgs, resources, sites, Target, targets } from "@server/db";
|
||||
import { build } from "@server/build";
|
||||
import { sanitize } from "@server/lib/traefik/utils";
|
||||
|
||||
const redirectHttpsMiddlewareName = "redirect-to-https";
|
||||
const redirectToRootMiddlewareName = "redirect-to-root";
|
||||
const badgerMiddlewareName = "badger";
|
||||
|
||||
export async function getTraefikConfig(
|
||||
exitNodeId: number,
|
||||
siteTypes: string[],
|
||||
filterOutNamespaceDomains = false,
|
||||
generateLoginPageRouters = false
|
||||
): Promise<any> {
|
||||
// Define extended target type with site information
|
||||
type TargetWithSite = Target & {
|
||||
site: {
|
||||
siteId: number;
|
||||
type: string;
|
||||
subnet: string | null;
|
||||
exitNodeId: number | null;
|
||||
online: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
// Get resources with their targets and sites in a single optimized query
|
||||
// Start from sites on this exit node, then join to targets and resources
|
||||
const resourcesWithTargetsAndSites = await db
|
||||
.select({
|
||||
// Resource fields
|
||||
resourceId: resources.resourceId,
|
||||
resourceName: resources.name,
|
||||
fullDomain: resources.fullDomain,
|
||||
ssl: resources.ssl,
|
||||
http: resources.http,
|
||||
proxyPort: resources.proxyPort,
|
||||
protocol: resources.protocol,
|
||||
subdomain: resources.subdomain,
|
||||
domainId: resources.domainId,
|
||||
enabled: resources.enabled,
|
||||
stickySession: resources.stickySession,
|
||||
tlsServerName: resources.tlsServerName,
|
||||
setHostHeader: resources.setHostHeader,
|
||||
enableProxy: resources.enableProxy,
|
||||
headers: resources.headers,
|
||||
// Target fields
|
||||
targetId: targets.targetId,
|
||||
targetEnabled: targets.enabled,
|
||||
ip: targets.ip,
|
||||
method: targets.method,
|
||||
port: targets.port,
|
||||
internalPort: targets.internalPort,
|
||||
hcHealth: targetHealthCheck.hcHealth,
|
||||
path: targets.path,
|
||||
pathMatchType: targets.pathMatchType,
|
||||
priority: targets.priority,
|
||||
|
||||
// Site fields
|
||||
siteId: sites.siteId,
|
||||
siteType: sites.type,
|
||||
siteOnline: sites.online,
|
||||
subnet: sites.subnet,
|
||||
exitNodeId: sites.exitNodeId,
|
||||
// Namespace
|
||||
domainNamespaceId: domainNamespaces.domainNamespaceId,
|
||||
// Certificate
|
||||
certificateStatus: certificates.status
|
||||
})
|
||||
.from(sites)
|
||||
.innerJoin(targets, eq(targets.siteId, sites.siteId))
|
||||
.innerJoin(resources, eq(resources.resourceId, targets.resourceId))
|
||||
.leftJoin(certificates, eq(certificates.domainId, resources.domainId))
|
||||
.leftJoin(
|
||||
targetHealthCheck,
|
||||
eq(targetHealthCheck.targetId, targets.targetId)
|
||||
)
|
||||
.leftJoin(
|
||||
domainNamespaces,
|
||||
eq(domainNamespaces.domainId, resources.domainId)
|
||||
) // THIS IS CLOUD ONLY TO FILTER OUT THE DOMAIN NAMESPACES IF REQUIRED
|
||||
.where(
|
||||
and(
|
||||
eq(targets.enabled, true),
|
||||
eq(resources.enabled, true),
|
||||
// or(
|
||||
eq(sites.exitNodeId, exitNodeId),
|
||||
// isNull(sites.exitNodeId)
|
||||
// ),
|
||||
or(
|
||||
ne(targetHealthCheck.hcHealth, "unhealthy"), // Exclude unhealthy targets
|
||||
isNull(targetHealthCheck.hcHealth) // Include targets with no health check record
|
||||
),
|
||||
inArray(sites.type, siteTypes),
|
||||
config.getRawConfig().traefik.allow_raw_resources
|
||||
? isNotNull(resources.http) // ignore the http check if allow_raw_resources is true
|
||||
: eq(resources.http, true)
|
||||
)
|
||||
)
|
||||
.orderBy(desc(targets.priority), targets.targetId); // stable ordering
|
||||
|
||||
// Group by resource and include targets with their unique site data
|
||||
const resourcesMap = new Map();
|
||||
|
||||
resourcesWithTargetsAndSites.forEach((row) => {
|
||||
const resourceId = row.resourceId;
|
||||
const resourceName = sanitize(row.resourceName) || "";
|
||||
const targetPath = sanitize(row.path) || ""; // Handle null/undefined paths
|
||||
const pathMatchType = row.pathMatchType || "";
|
||||
const priority = row.priority ?? 100;
|
||||
|
||||
if (filterOutNamespaceDomains && row.domainNamespaceId) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Create a unique key combining resourceId and path+pathMatchType
|
||||
const pathKey = [targetPath, pathMatchType].filter(Boolean).join("-");
|
||||
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
|
||||
const key = sanitize(mapKey);
|
||||
|
||||
if (!resourcesMap.has(key)) {
|
||||
resourcesMap.set(key, {
|
||||
resourceId: row.resourceId,
|
||||
name: resourceName,
|
||||
fullDomain: row.fullDomain,
|
||||
ssl: row.ssl,
|
||||
http: row.http,
|
||||
proxyPort: row.proxyPort,
|
||||
protocol: row.protocol,
|
||||
subdomain: row.subdomain,
|
||||
domainId: row.domainId,
|
||||
enabled: row.enabled,
|
||||
stickySession: row.stickySession,
|
||||
tlsServerName: row.tlsServerName,
|
||||
setHostHeader: row.setHostHeader,
|
||||
enableProxy: row.enableProxy,
|
||||
certificateStatus: row.certificateStatus,
|
||||
targets: [],
|
||||
headers: row.headers,
|
||||
path: row.path, // the targets will all have the same path
|
||||
pathMatchType: row.pathMatchType, // the targets will all have the same pathMatchType
|
||||
priority: priority // may be null, we fallback later
|
||||
});
|
||||
}
|
||||
|
||||
// Add target with its associated site data
|
||||
resourcesMap.get(key).targets.push({
|
||||
resourceId: row.resourceId,
|
||||
targetId: row.targetId,
|
||||
ip: row.ip,
|
||||
method: row.method,
|
||||
port: row.port,
|
||||
internalPort: row.internalPort,
|
||||
enabled: row.targetEnabled,
|
||||
priority: row.priority,
|
||||
site: {
|
||||
siteId: row.siteId,
|
||||
type: row.siteType,
|
||||
subnet: row.subnet,
|
||||
exitNodeId: row.exitNodeId,
|
||||
online: row.siteOnline
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// make sure we have at least one resource
|
||||
if (resourcesMap.size === 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const config_output: any = {
|
||||
http: {
|
||||
middlewares: {
|
||||
[redirectHttpsMiddlewareName]: {
|
||||
redirectScheme: {
|
||||
scheme: "https"
|
||||
}
|
||||
},
|
||||
[redirectToRootMiddlewareName]: {
|
||||
redirectRegex: {
|
||||
regex: "^(https?)://([^/]+)(/.*)?",
|
||||
replacement: "${1}://${2}/auth/org",
|
||||
permanent: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// get the key and the resource
|
||||
for (const [key, resource] of resourcesMap.entries()) {
|
||||
const targets = resource.targets;
|
||||
|
||||
const routerName = `${key}-${resource.name}-router`;
|
||||
const serviceName = `${key}-${resource.name}-service`;
|
||||
const fullDomain = `${resource.fullDomain}`;
|
||||
const transportName = `${key}-transport`;
|
||||
const headersMiddlewareName = `${key}-headers-middleware`;
|
||||
|
||||
if (!resource.enabled) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (resource.http) {
|
||||
if (!resource.domainId) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!resource.fullDomain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (resource.certificateStatus !== "valid") {
|
||||
logger.debug(
|
||||
`Resource ${resource.resourceId} has certificate stats ${resource.certificateStats}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// add routers and services empty objects if they don't exist
|
||||
if (!config_output.http.routers) {
|
||||
config_output.http.routers = {};
|
||||
}
|
||||
|
||||
if (!config_output.http.services) {
|
||||
config_output.http.services = {};
|
||||
}
|
||||
|
||||
const domainParts = fullDomain.split(".");
|
||||
let wildCard;
|
||||
if (domainParts.length <= 2) {
|
||||
wildCard = `*.${domainParts.join(".")}`;
|
||||
} else {
|
||||
wildCard = `*.${domainParts.slice(1).join(".")}`;
|
||||
}
|
||||
|
||||
if (!resource.subdomain) {
|
||||
wildCard = resource.fullDomain;
|
||||
}
|
||||
|
||||
const configDomain = config.getDomain(resource.domainId);
|
||||
|
||||
let certResolver: string, preferWildcardCert: boolean;
|
||||
if (!configDomain) {
|
||||
certResolver = config.getRawConfig().traefik.cert_resolver;
|
||||
preferWildcardCert =
|
||||
config.getRawConfig().traefik.prefer_wildcard_cert;
|
||||
} else {
|
||||
certResolver = configDomain.cert_resolver;
|
||||
preferWildcardCert = configDomain.prefer_wildcard_cert;
|
||||
}
|
||||
|
||||
let tls = {};
|
||||
if (build == "oss") {
|
||||
tls = {
|
||||
certResolver: certResolver,
|
||||
...(preferWildcardCert
|
||||
? {
|
||||
domains: [
|
||||
{
|
||||
main: wildCard
|
||||
}
|
||||
]
|
||||
}
|
||||
: {})
|
||||
};
|
||||
}
|
||||
|
||||
const additionalMiddlewares =
|
||||
config.getRawConfig().traefik.additional_middlewares || [];
|
||||
|
||||
const routerMiddlewares = [
|
||||
badgerMiddlewareName,
|
||||
...additionalMiddlewares
|
||||
];
|
||||
|
||||
if (resource.headers || resource.setHostHeader) {
|
||||
// if there are headers, parse them into an object
|
||||
const headersObj: { [key: string]: string } = {};
|
||||
if (resource.headers) {
|
||||
let headersArr: { name: string; value: string }[] = [];
|
||||
try {
|
||||
headersArr = JSON.parse(resource.headers) as {
|
||||
name: string;
|
||||
value: string;
|
||||
}[];
|
||||
} catch (e) {
|
||||
logger.warn(
|
||||
`Failed to parse headers for resource ${resource.resourceId}: ${e}`
|
||||
);
|
||||
}
|
||||
|
||||
headersArr.forEach((header) => {
|
||||
headersObj[header.name] = header.value;
|
||||
});
|
||||
}
|
||||
|
||||
if (resource.setHostHeader) {
|
||||
headersObj["Host"] = resource.setHostHeader;
|
||||
}
|
||||
|
||||
// check if the object is not empty
|
||||
if (Object.keys(headersObj).length > 0) {
|
||||
// Add the headers middleware
|
||||
if (!config_output.http.middlewares) {
|
||||
config_output.http.middlewares = {};
|
||||
}
|
||||
config_output.http.middlewares[headersMiddlewareName] = {
|
||||
headers: {
|
||||
customRequestHeaders: headersObj
|
||||
}
|
||||
};
|
||||
|
||||
routerMiddlewares.push(headersMiddlewareName);
|
||||
}
|
||||
}
|
||||
|
||||
let rule = `Host(\`${fullDomain}\`)`;
|
||||
|
||||
// priority logic
|
||||
let priority: number;
|
||||
if (resource.priority && resource.priority != 100) {
|
||||
priority = resource.priority;
|
||||
} else {
|
||||
priority = 100;
|
||||
if (resource.path && resource.pathMatchType) {
|
||||
priority += 10;
|
||||
if (resource.pathMatchType === "exact") {
|
||||
priority += 5;
|
||||
} else if (resource.pathMatchType === "prefix") {
|
||||
priority += 3;
|
||||
} else if (resource.pathMatchType === "regex") {
|
||||
priority += 2;
|
||||
}
|
||||
if (resource.path === "/") {
|
||||
priority = 1; // lowest for catch-all
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (resource.path && resource.pathMatchType) {
|
||||
//priority += 1;
|
||||
// add path to rule based on match type
|
||||
let path = resource.path;
|
||||
// if the path doesn't start with a /, add it
|
||||
if (!path.startsWith("/")) {
|
||||
path = `/${path}`;
|
||||
}
|
||||
if (resource.pathMatchType === "exact") {
|
||||
rule += ` && Path(\`${path}\`)`;
|
||||
} else if (resource.pathMatchType === "prefix") {
|
||||
rule += ` && PathPrefix(\`${path}\`)`;
|
||||
} else if (resource.pathMatchType === "regex") {
|
||||
rule += ` && PathRegexp(\`${resource.path}\`)`; // this is the raw path because it's a regex
|
||||
}
|
||||
}
|
||||
|
||||
config_output.http.routers![routerName] = {
|
||||
entryPoints: [
|
||||
resource.ssl
|
||||
? config.getRawConfig().traefik.https_entrypoint
|
||||
: config.getRawConfig().traefik.http_entrypoint
|
||||
],
|
||||
middlewares: routerMiddlewares,
|
||||
service: serviceName,
|
||||
rule: rule,
|
||||
priority: priority,
|
||||
...(resource.ssl ? { tls } : {})
|
||||
};
|
||||
|
||||
if (resource.ssl) {
|
||||
config_output.http.routers![routerName + "-redirect"] = {
|
||||
entryPoints: [
|
||||
config.getRawConfig().traefik.http_entrypoint
|
||||
],
|
||||
middlewares: [redirectHttpsMiddlewareName],
|
||||
service: serviceName,
|
||||
rule: rule,
|
||||
priority: priority
|
||||
};
|
||||
}
|
||||
|
||||
config_output.http.services![serviceName] = {
|
||||
loadBalancer: {
|
||||
servers: (() => {
|
||||
// Check if any sites are online
|
||||
// THIS IS SO THAT THERE IS SOME IMMEDIATE FEEDBACK
|
||||
// EVEN IF THE SITES HAVE NOT UPDATED YET FROM THE
|
||||
// RECEIVE BANDWIDTH ENDPOINT.
|
||||
|
||||
// TODO: HOW TO HANDLE ^^^^^^ BETTER
|
||||
const anySitesOnline = (
|
||||
targets as TargetWithSite[]
|
||||
).some((target: TargetWithSite) => target.site.online);
|
||||
|
||||
return (
|
||||
(targets as TargetWithSite[])
|
||||
.filter((target: TargetWithSite) => {
|
||||
if (!target.enabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If any sites are online, exclude offline sites
|
||||
if (anySitesOnline && !target.site.online) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
target.site.type === "local" ||
|
||||
target.site.type === "wireguard"
|
||||
) {
|
||||
if (
|
||||
!target.ip ||
|
||||
!target.port ||
|
||||
!target.method
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
} else if (target.site.type === "newt") {
|
||||
if (
|
||||
!target.internalPort ||
|
||||
!target.method ||
|
||||
!target.site.subnet
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map((target: TargetWithSite) => {
|
||||
if (
|
||||
target.site.type === "local" ||
|
||||
target.site.type === "wireguard"
|
||||
) {
|
||||
return {
|
||||
url: `${target.method}://${target.ip}:${target.port}`
|
||||
};
|
||||
} else if (target.site.type === "newt") {
|
||||
const ip =
|
||||
target.site.subnet!.split("/")[0];
|
||||
return {
|
||||
url: `${target.method}://${ip}:${target.internalPort}`
|
||||
};
|
||||
}
|
||||
})
|
||||
// filter out duplicates
|
||||
.filter(
|
||||
(v, i, a) =>
|
||||
a.findIndex(
|
||||
(t) => t && v && t.url === v.url
|
||||
) === i
|
||||
)
|
||||
);
|
||||
})(),
|
||||
...(resource.stickySession
|
||||
? {
|
||||
sticky: {
|
||||
cookie: {
|
||||
name: "p_sticky", // TODO: make this configurable via config.yml like other cookies
|
||||
secure: resource.ssl,
|
||||
httpOnly: true
|
||||
}
|
||||
}
|
||||
}
|
||||
: {})
|
||||
}
|
||||
};
|
||||
|
||||
// Add the serversTransport if TLS server name is provided
|
||||
if (resource.tlsServerName) {
|
||||
if (!config_output.http.serversTransports) {
|
||||
config_output.http.serversTransports = {};
|
||||
}
|
||||
config_output.http.serversTransports![transportName] = {
|
||||
serverName: resource.tlsServerName,
|
||||
//unfortunately the following needs to be set. traefik doesn't merge the default serverTransport settings
|
||||
// if defined in the static config and here. if not set, self-signed certs won't work
|
||||
insecureSkipVerify: true
|
||||
};
|
||||
config_output.http.services![
|
||||
serviceName
|
||||
].loadBalancer.serversTransport = transportName;
|
||||
}
|
||||
} else {
|
||||
// Non-HTTP (TCP/UDP) configuration
|
||||
if (!resource.enableProxy) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const protocol = resource.protocol.toLowerCase();
|
||||
const port = resource.proxyPort;
|
||||
|
||||
if (!port) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!config_output[protocol]) {
|
||||
config_output[protocol] = {
|
||||
routers: {},
|
||||
services: {}
|
||||
};
|
||||
}
|
||||
|
||||
config_output[protocol].routers[routerName] = {
|
||||
entryPoints: [`${protocol}-${port}`],
|
||||
service: serviceName,
|
||||
...(protocol === "tcp" ? { rule: "HostSNI(`*`)" } : {})
|
||||
};
|
||||
|
||||
config_output[protocol].services[serviceName] = {
|
||||
loadBalancer: {
|
||||
servers: (() => {
|
||||
// Check if any sites are online
|
||||
const anySitesOnline = (
|
||||
targets as TargetWithSite[]
|
||||
).some((target: TargetWithSite) => target.site.online);
|
||||
|
||||
return (targets as TargetWithSite[])
|
||||
.filter((target: TargetWithSite) => {
|
||||
if (!target.enabled) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If any sites are online, exclude offline sites
|
||||
if (anySitesOnline && !target.site.online) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
target.site.type === "local" ||
|
||||
target.site.type === "wireguard"
|
||||
) {
|
||||
if (!target.ip || !target.port) {
|
||||
return false;
|
||||
}
|
||||
} else if (target.site.type === "newt") {
|
||||
if (
|
||||
!target.internalPort ||
|
||||
!target.site.subnet
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map((target: TargetWithSite) => {
|
||||
if (
|
||||
target.site.type === "local" ||
|
||||
target.site.type === "wireguard"
|
||||
) {
|
||||
return {
|
||||
address: `${target.ip}:${target.port}`
|
||||
};
|
||||
} else if (target.site.type === "newt") {
|
||||
const ip =
|
||||
target.site.subnet!.split("/")[0];
|
||||
return {
|
||||
address: `${ip}:${target.internalPort}`
|
||||
};
|
||||
}
|
||||
});
|
||||
})(),
|
||||
...(resource.stickySession
|
||||
? {
|
||||
sticky: {
|
||||
ipStrategy: {
|
||||
depth: 0,
|
||||
sourcePort: true
|
||||
}
|
||||
}
|
||||
}
|
||||
: {})
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (generateLoginPageRouters) {
|
||||
const exitNodeLoginPages = await db
|
||||
.select({
|
||||
loginPageId: loginPage.loginPageId,
|
||||
fullDomain: loginPage.fullDomain,
|
||||
exitNodeId: exitNodes.exitNodeId,
|
||||
domainId: loginPage.domainId,
|
||||
certificateStatus: certificates.status
|
||||
})
|
||||
.from(loginPage)
|
||||
.innerJoin(
|
||||
exitNodes,
|
||||
eq(exitNodes.exitNodeId, loginPage.exitNodeId)
|
||||
)
|
||||
.leftJoin(
|
||||
certificates,
|
||||
eq(certificates.domainId, loginPage.domainId)
|
||||
)
|
||||
.where(eq(exitNodes.exitNodeId, exitNodeId));
|
||||
|
||||
if (exitNodeLoginPages.length > 0) {
|
||||
if (!config_output.http.services) {
|
||||
config_output.http.services = {};
|
||||
}
|
||||
|
||||
if (!config_output.http.services["landing-service"]) {
|
||||
config_output.http.services["landing-service"] = {
|
||||
loadBalancer: {
|
||||
servers: [
|
||||
{
|
||||
url: `http://${
|
||||
config.getRawConfig().server
|
||||
.internal_hostname
|
||||
}:${config.getRawConfig().server.next_port}`
|
||||
}
|
||||
]
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
for (const lp of exitNodeLoginPages) {
|
||||
if (!lp.domainId) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!lp.fullDomain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (lp.certificateStatus !== "valid") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// auth-allowed:
|
||||
// rule: "Host(`auth.pangolin.internal`) && (PathRegexp(`^/auth/resource/[0-9]+$`) || PathPrefix(`/_next`))"
|
||||
// service: next-service
|
||||
// entryPoints:
|
||||
// - websecure
|
||||
|
||||
const routerName = `loginpage-${lp.loginPageId}`;
|
||||
const fullDomain = `${lp.fullDomain}`;
|
||||
|
||||
if (!config_output.http.routers) {
|
||||
config_output.http.routers = {};
|
||||
}
|
||||
|
||||
config_output.http.routers![routerName + "-router"] = {
|
||||
entryPoints: [
|
||||
config.getRawConfig().traefik.https_entrypoint
|
||||
],
|
||||
service: "landing-service",
|
||||
rule: `Host(\`${fullDomain}\`) && (PathRegexp(\`^/auth/resource/[^/]+$\`) || PathRegexp(\`^/auth/idp/[0-9]+/oidc/callback\`) || PathPrefix(\`/_next\`) || Path(\`/auth/org\`) || PathRegexp(\`^/__nextjs*\`))`,
|
||||
priority: 203,
|
||||
tls: {}
|
||||
};
|
||||
|
||||
// auth-catchall:
|
||||
// rule: "Host(`auth.example.com`)"
|
||||
// middlewares:
|
||||
// - redirect-to-root
|
||||
// service: next-service
|
||||
// entryPoints:
|
||||
// - web
|
||||
|
||||
config_output.http.routers![routerName + "-catchall"] = {
|
||||
entryPoints: [
|
||||
config.getRawConfig().traefik.https_entrypoint
|
||||
],
|
||||
middlewares: [redirectToRootMiddlewareName],
|
||||
service: "landing-service",
|
||||
rule: `Host(\`${fullDomain}\`)`,
|
||||
priority: 202,
|
||||
tls: {}
|
||||
};
|
||||
|
||||
// we need to add a redirect from http to https too
|
||||
config_output.http.routers![routerName + "-redirect"] = {
|
||||
entryPoints: [
|
||||
config.getRawConfig().traefik.http_entrypoint
|
||||
],
|
||||
middlewares: [redirectHttpsMiddlewareName],
|
||||
service: "landing-service",
|
||||
rule: `Host(\`${fullDomain}\`)`,
|
||||
priority: 201
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return config_output;
|
||||
}
|
||||
14
server/private/lib/traefik/index.ts
Normal file
14
server/private/lib/traefik/index.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./getTraefikConfig";
|
||||
Reference in New Issue
Block a user