Merge branch 'dev' into feat/login-page-customization
@@ -11,11 +11,14 @@
 * This file is not licensed under the AGPLv3.
 */

import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { RemoteExitNode, remoteExitNodes, remoteExitNodeSessions, RemoteExitNodeSession } from "@server/db";
import {
RemoteExitNode,
remoteExitNodes,
remoteExitNodeSessions,
RemoteExitNodeSession
} from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";

@@ -23,30 +26,39 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;

export async function createRemoteExitNodeSession(
token: string,
remoteExitNodeId: string,
remoteExitNodeId: string
): Promise<RemoteExitNodeSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: RemoteExitNodeSession = {
sessionId: sessionId,
remoteExitNodeId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(remoteExitNodeSessions).values(session);
return session;
}

export async function validateRemoteExitNodeSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ remoteExitNode: remoteExitNodes, session: remoteExitNodeSessions })
.select({
remoteExitNode: remoteExitNodes,
session: remoteExitNodeSessions
})
.from(remoteExitNodeSessions)
.innerJoin(remoteExitNodes, eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodes.remoteExitNodeId))
.innerJoin(
remoteExitNodes,
eq(
remoteExitNodeSessions.remoteExitNodeId,
remoteExitNodes.remoteExitNodeId
)
)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, remoteExitNode: null };
@@ -58,26 +70,32 @@ export async function validateRemoteExitNodeSessionToken(
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
return { session: null, remoteExitNode: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(remoteExitNodeSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
}
return { session, remoteExitNode };
}

export async function invalidateRemoteExitNodeSession(sessionId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.sessionId, sessionId));
export async function invalidateRemoteExitNodeSession(
sessionId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
}

export async function invalidateAllRemoteExitNodeSessions(remoteExitNodeId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
export async function invalidateAllRemoteExitNodeSessions(
remoteExitNodeId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
}

export type SessionValidationResult =
@@ -25,4 +25,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

@@ -12,4 +12,4 @@
 */

export * from "./getOrgTierData";
export * from "./createCustomer";
export * from "./createCustomer";

@@ -55,7 +55,6 @@ export async function getValidCertificatesForDomains(
domains: Set<string>,
useCache: boolean = true
): Promise<Array<CertificateResult>> {

loadEncryptData(); // Ensure encryption key is loaded

const finalResults: CertificateResult[] = [];
@@ -12,14 +12,7 @@
 */

import { build } from "@server/build";
import {
db,
Org,
orgs,
ResourceSession,
sessions,
users
} from "@server/db";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license";

@@ -83,9 +83,6 @@ export class PrivateConfig {
? this.rawPrivateConfig.branding?.logo?.navbar?.height.toString()
: undefined;

process.env.BRANDING_FAVICON_PATH =
this.rawPrivateConfig.branding?.favicon_path;

process.env.BRANDING_APP_NAME =
this.rawPrivateConfig.branding?.app_name || "Pangolin";

@@ -95,13 +92,9 @@ export class PrivateConfig {
);
}

process.env.LOGIN_PAGE_TITLE_TEXT =
this.rawPrivateConfig.branding?.login_page?.title_text || "";
process.env.LOGIN_PAGE_SUBTITLE_TEXT =
this.rawPrivateConfig.branding?.login_page?.subtitle_text || "";

process.env.SIGNUP_PAGE_TITLE_TEXT =
this.rawPrivateConfig.branding?.signup_page?.title_text || "";
process.env.SIGNUP_PAGE_SUBTITLE_TEXT =
this.rawPrivateConfig.branding?.signup_page?.subtitle_text || "";

@@ -66,7 +66,9 @@ export async function sendToExitNode(
// logger.debug(`Configured local exit node name: ${config.getRawConfig().gerbil.exit_node_name}`);

if (exitNode.name == config.getRawConfig().gerbil.exit_node_name) {
hostname = privateConfig.getRawPrivateConfig().gerbil.local_exit_node_reachable_at;
hostname =
privateConfig.getRawPrivateConfig().gerbil
.local_exit_node_reachable_at;
}

if (!hostname) {
@@ -44,43 +44,53 @@ async function checkExitNodeOnlineStatus(
const delayBetweenAttempts = 100; // 100ms delay between starting each attempt

// Create promises for all attempts with staggered delays
const attemptPromises = Array.from({ length: maxAttempts }, async (_, index) => {
const attemptNumber = index + 1;

// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) => setTimeout(resolve, delayBetweenAttempts * index));
}
const attemptPromises = Array.from(
{ length: maxAttempts },
async (_, index) => {
const attemptNumber = index + 1;

try {
const response = await axios.get(`http://${endpoint}/ping`, {
timeout: timeoutMs,
validateStatus: (status) => status === 200
});

if (response.status === 200) {
logger.debug(
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) =>
setTimeout(resolve, delayBetweenAttempts * index)
);
return { success: true, attemptNumber };
}
return { success: false, attemptNumber, error: 'Non-200 status' };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };

try {
const response = await axios.get(`http://${endpoint}/ping`, {
timeout: timeoutMs,
validateStatus: (status) => status === 200
});

if (response.status === 200) {
logger.debug(
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
);
return { success: true, attemptNumber };
}
return {
success: false,
attemptNumber,
error: "Non-200 status"
};
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };
}
}
});
);

try {
// Wait for the first successful response or all to fail
const results = await Promise.allSettled(attemptPromises);

// Check if any attempt succeeded
for (const result of results) {
if (result.status === 'fulfilled' && result.value.success) {
if (result.status === "fulfilled" && result.value.success) {
return true;
}
}
@@ -137,7 +147,11 @@ export async function verifyExitNodeOrgAccess(
return { hasAccess: false, exitNode };
}

export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
export async function listExitNodes(
orgId: string,
filterOnline = false,
noCloud = false
) {
const allExitNodes = await db
.select({
exitNodeId: exitNodes.exitNodeId,
@@ -166,7 +180,10 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
eq(exitNodes.type, "gerbil"),
or(
// only choose nodes that are in the same region
eq(exitNodes.region, config.getRawPrivateConfig().app.region),
eq(
exitNodes.region,
config.getRawPrivateConfig().app.region
),
isNull(exitNodes.region) // or for enterprise where region is not set
)
),
@@ -191,7 +208,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
// let online: boolean;
// if (filterOnline && node.type == "remoteExitNode") {
// try {
// const isActuallyOnline = await checkExitNodeOnlineStatus(
// const isActuallyOnline = await checkExitNodeOnlineStatus(
// node.endpoint
// );

@@ -225,7 +242,8 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
node.type === "remoteExitNode" && (!filterOnline || node.online)
);
const gerbilExitNodes = allExitNodes.filter(
(node) => node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
(node) =>
node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
);

// THIS PROVIDES THE FALL
@@ -334,7 +352,11 @@ export function selectBestExitNode(
return fallbackNode;
}

export async function checkExitNodeOrg(exitNodeId: number, orgId: string, trx: Transaction | typeof db = db) {
export async function checkExitNodeOrg(
exitNodeId: number,
orgId: string,
trx: Transaction | typeof db = db
) {
const [exitNodeOrg] = await trx
.select()
.from(exitNodeOrgs)

@@ -12,4 +12,4 @@
 */

export * from "./exitNodeComms";
export * from "./exitNodes";
export * from "./exitNodes";
@@ -177,7 +177,9 @@ export class LockManager {
const exists = value !== null;
const ownedByMe =
exists &&
value!.startsWith(`${config.getRawConfig().gerbil.exit_node_name}:`);
value!.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
);
const owner = exists ? value!.split(":")[0] : undefined;

return {

@@ -16,6 +16,7 @@ import { getCountryCodeForIp } from "@server/lib/geoip";
import logger from "@server/logger";
import { and, eq, lt } from "drizzle-orm";
import cache from "@server/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";

async function getAccessDays(orgId: string): Promise<number> {
// check cache first
@@ -47,9 +48,7 @@ async function getAccessDays(orgId: string): Promise<number> {
}

export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
const now = Math.floor(Date.now() / 1000);

const cutoffTimestamp = now - retentionDays * 24 * 60 * 60;
const cutoffTimestamp = calculateCutoffTimestamp(retentionDays);

try {
await db
@@ -14,15 +14,15 @@
// Simple test file for the rate limit service with Redis
// Run with: npx ts-node rateLimitService.test.ts

import { RateLimitService } from './rateLimit';
import { RateLimitService } from "./rateLimit";

function generateClientId() {
return 'client-' + Math.random().toString(36).substring(2, 15);
return "client-" + Math.random().toString(36).substring(2, 15);
}

async function runTests() {
console.log('Starting Rate Limit Service Tests...\n');

console.log("Starting Rate Limit Service Tests...\n");

const rateLimitService = new RateLimitService();
let testsPassed = 0;
let testsTotal = 0;
@@ -47,36 +47,54 @@ async function runTests() {
}

// Test 1: Basic rate limiting
await test('Should allow requests under the limit', async () => {
await test("Should allow requests under the limit", async () => {
const clientId = generateClientId();
const maxRequests = 5;

for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
assert(result.totalHits === i + 1, `Expected ${i + 1} hits, got ${result.totalHits}`);
assert(
result.totalHits === i + 1,
`Expected ${i + 1} hits, got ${result.totalHits}`
);
}
});

// Test 2: Rate limit blocking
await test('Should block requests over the limit', async () => {
await test("Should block requests over the limit", async () => {
const clientId = generateClientId();
const maxRequests = 30;

// Use up all allowed requests
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
}

// Next request should be blocked
const blockedResult = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(blockedResult.isLimited, 'Request should be blocked');
assert(blockedResult.reason === 'global', 'Should be blocked for global reason');
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(blockedResult.isLimited, "Request should be blocked");
assert(
blockedResult.reason === "global",
"Should be blocked for global reason"
);
});

// Test 3: Message type limits
await test('Should handle message type limits', async () => {
await test("Should handle message type limits", async () => {
const clientId = generateClientId();
const globalMax = 10;
const messageTypeMax = 2;
@@ -84,54 +102,64 @@ async function runTests() {
// Send messages of type 'ping' up to the limit
for (let i = 0; i < messageTypeMax - 1; i++) {
const result = await rateLimitService.checkRateLimit(
clientId,
'ping',
globalMax,
clientId,
"ping",
globalMax,
messageTypeMax
);
assert(!result.isLimited, `Ping message ${i + 1} should be allowed`);
assert(
!result.isLimited,
`Ping message ${i + 1} should be allowed`
);
}

// Next 'ping' should be blocked
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
'ping',
globalMax,
clientId,
"ping",
globalMax,
messageTypeMax
);
assert(blockedResult.isLimited, 'Ping message should be blocked');
assert(blockedResult.reason === 'message_type:ping', 'Should be blocked for message type');
assert(blockedResult.isLimited, "Ping message should be blocked");
assert(
blockedResult.reason === "message_type:ping",
"Should be blocked for message type"
);

// Other message types should still work
const otherResult = await rateLimitService.checkRateLimit(
clientId,
'pong',
globalMax,
clientId,
"pong",
globalMax,
messageTypeMax
);
assert(!otherResult.isLimited, 'Pong message should be allowed');
assert(!otherResult.isLimited, "Pong message should be allowed");
});

// Test 4: Reset functionality
await test('Should reset client correctly', async () => {
await test("Should reset client correctly", async () => {
const clientId = generateClientId();
const maxRequests = 3;

// Use up some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, 'test', maxRequests);
await rateLimitService.checkRateLimit(clientId, "test", maxRequests);

// Reset the client
await rateLimitService.resetKey(clientId);

// Should be able to make fresh requests
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(!result.isLimited, 'Request after reset should be allowed');
assert(result.totalHits === 1, 'Should have 1 hit after reset');
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, "Request after reset should be allowed");
assert(result.totalHits === 1, "Should have 1 hit after reset");
});

// Test 5: Different clients are independent
await test('Should handle different clients independently', async () => {
await test("Should handle different clients independently", async () => {
const client1 = generateClientId();
const client2 = generateClientId();
const maxRequests = 2;
@@ -139,43 +167,62 @@ async function runTests() {
// Client 1 uses up their limit
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
const client1Blocked = await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
assert(client1Blocked.isLimited, 'Client 1 should be blocked');
const client1Blocked = await rateLimitService.checkRateLimit(
client1,
undefined,
maxRequests
);
assert(client1Blocked.isLimited, "Client 1 should be blocked");

// Client 2 should still be able to make requests
const client2Result = await rateLimitService.checkRateLimit(client2, undefined, maxRequests);
assert(!client2Result.isLimited, 'Client 2 should not be blocked');
assert(client2Result.totalHits === 1, 'Client 2 should have 1 hit');
const client2Result = await rateLimitService.checkRateLimit(
client2,
undefined,
maxRequests
);
assert(!client2Result.isLimited, "Client 2 should not be blocked");
assert(client2Result.totalHits === 1, "Client 2 should have 1 hit");
});

// Test 6: Decrement functionality
await test('Should decrement correctly', async () => {
await test("Should decrement correctly", async () => {
const clientId = generateClientId();
const maxRequests = 5;

// Make some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
let result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits before decrement');
let result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(result.totalHits === 3, "Should have 3 hits before decrement");

// Decrement
await rateLimitService.decrementRateLimit(clientId);

// Next request should reflect the decrement
result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits after decrement + increment');
result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(
result.totalHits === 3,
"Should have 3 hits after decrement + increment"
);
});

// Wait a moment for any pending Redis operations
console.log('\nWaiting for Redis sync...');
await new Promise(resolve => setTimeout(resolve, 1000));
console.log("\nWaiting for Redis sync...");
await new Promise((resolve) => setTimeout(resolve, 1000));

// Force sync to test Redis integration
await test('Should sync to Redis', async () => {
await test("Should sync to Redis", async () => {
await rateLimitService.forceSyncAllPendingData();
// If this doesn't throw, Redis sync is working
assert(true, 'Redis sync completed');
assert(true, "Redis sync completed");
});

// Cleanup
@@ -185,18 +232,18 @@ async function runTests() {
console.log(`\n--- Test Results ---`);
console.log(`✅ Passed: ${testsPassed}/${testsTotal}`);
console.log(`❌ Failed: ${testsTotal - testsPassed}/${testsTotal}`);

if (testsPassed === testsTotal) {
console.log('\n🎉 All tests passed!');
console.log("\n🎉 All tests passed!");
process.exit(0);
} else {
console.log('\n💥 Some tests failed!');
console.log("\n💥 Some tests failed!");
process.exit(1);
}
}

// Run the tests
runTests().catch(error => {
console.error('Test runner error:', error);
runTests().catch((error) => {
console.error("Test runner error:", error);
process.exit(1);
});
});
@@ -40,7 +40,8 @@ interface RateLimitResult {

export class RateLimitService {
private localRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> =
new Map();
private cleanupInterval: NodeJS.Timeout | null = null;
private forceSyncInterval: NodeJS.Timeout | null = null;

@@ -68,12 +69,18 @@ export class RateLimitService {
return `ratelimit:${clientId}`;
}

private getMessageTypeRateLimitKey(clientId: string, messageType: string): string {
private getMessageTypeRateLimitKey(
clientId: string,
messageType: string
): string {
return `ratelimit:${clientId}:${messageType}`;
}

// Helper function to clean up old timestamp fields from a Redis hash
private async cleanupOldTimestamps(key: string, windowStart: number): Promise<void> {
private async cleanupOldTimestamps(
key: string,
windowStart: number
): Promise<void> {
if (!redisManager.isRedisEnabled()) return;

try {
@@ -101,10 +108,15 @@ export class RateLimitService {
const batch = fieldsToDelete.slice(i, i + batchSize);
await client.hdel(key, ...batch);
}
logger.debug(`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`);
logger.debug(
`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`
);
}
} catch (error) {
logger.error(`Failed to cleanup old timestamps for key ${key}:`, error);
logger.error(
`Failed to cleanup old timestamps for key ${key}:`,
error
);
// Don't throw - cleanup failures shouldn't block rate limiting
}
}
@@ -114,7 +126,8 @@ export class RateLimitService {
clientId: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;

try {
const currentTime = Math.floor(Date.now() / 1000);
@@ -132,7 +145,11 @@ export class RateLimitService {
const newValue = (
parseInt(currentValue || "0") + tracker.pendingCount
).toString();
await redisManager.hset(globalKey, currentTime.toString(), newValue);
await redisManager.hset(
globalKey,
currentTime.toString(),
newValue
);

// Set TTL using the client directly - this prevents the key from persisting forever
if (redisManager.getClient()) {
@@ -145,7 +162,9 @@ export class RateLimitService {
tracker.lastSyncedCount = tracker.count;
tracker.pendingCount = 0;

logger.debug(`Synced global rate limit to Redis for client ${clientId}`);
logger.debug(
`Synced global rate limit to Redis for client ${clientId}`
);
} catch (error) {
logger.error("Failed to sync global rate limit to Redis:", error);
}
@@ -156,12 +175,16 @@ export class RateLimitService {
messageType: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;

try {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);

// Clean up old timestamp fields before writing
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
@@ -195,12 +218,17 @@ export class RateLimitService {
`Synced message type rate limit to Redis for client ${clientId}, type ${messageType}`
);
} catch (error) {
logger.error("Failed to sync message type rate limit to Redis:", error);
logger.error(
"Failed to sync message type rate limit to Redis:",
error
);
}
}

// Initialize local tracker from Redis data
private async initializeLocalTracker(clientId: string): Promise<RateLimitTracker> {
private async initializeLocalTracker(
clientId: string
): Promise<RateLimitTracker> {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;

@@ -215,14 +243,16 @@ export class RateLimitService {

try {
const globalKey = this.getRateLimitKey(clientId);

// Clean up old timestamp fields before reading
await this.cleanupOldTimestamps(globalKey, windowStart);

const globalRateLimitData = await redisManager.hgetall(globalKey);

let count = 0;
for (const [timestamp, countStr] of Object.entries(globalRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
globalRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
|
||||
lastSyncedCount: count
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error("Failed to initialize global tracker from Redis:", error);
|
||||
logger.error(
|
||||
"Failed to initialize global tracker from Redis:",
|
||||
error
|
||||
);
|
||||
return {
|
||||
count: 0,
|
||||
windowStart: currentTime,
|
||||
@@ -263,15 +296,21 @@ export class RateLimitService {
|
||||
}
|
||||
|
||||
try {
|
||||
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
|
||||
|
||||
const messageTypeKey = this.getMessageTypeRateLimitKey(
|
||||
clientId,
|
||||
messageType
|
||||
);
|
||||
|
||||
// Clean up old timestamp fields before reading
|
||||
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
|
||||
|
||||
const messageTypeRateLimitData = await redisManager.hgetall(messageTypeKey);
|
||||
|
||||
const messageTypeRateLimitData =
|
||||
await redisManager.hgetall(messageTypeKey);
|
||||
|
||||
let count = 0;
|
||||
for (const [timestamp, countStr] of Object.entries(messageTypeRateLimitData)) {
|
||||
for (const [timestamp, countStr] of Object.entries(
|
||||
messageTypeRateLimitData
|
||||
)) {
|
||||
const time = parseInt(timestamp);
|
||||
if (time >= windowStart) {
|
||||
count += parseInt(countStr);
|
||||
@@ -285,7 +324,10 @@ export class RateLimitService {
|
||||
lastSyncedCount: count
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error("Failed to initialize message type tracker from Redis:", error);
|
||||
logger.error(
|
||||
"Failed to initialize message type tracker from Redis:",
|
||||
error
|
||||
);
|
||||
return {
|
||||
count: 0,
|
||||
windowStart: currentTime,
|
||||
@@ -327,7 +369,10 @@ export class RateLimitService {
|
||||
isLimited: true,
|
||||
reason: "global",
|
||||
totalHits: globalTracker.count,
|
||||
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
|
||||
resetTime: new Date(
|
||||
(globalTracker.windowStart + Math.floor(windowMs / 1000)) *
|
||||
1000
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -339,19 +384,32 @@ export class RateLimitService {
|
||||
// Check message type specific rate limit if messageType is provided
|
||||
if (messageType) {
|
||||
const messageTypeKey = `${clientId}:${messageType}`;
|
||||
let messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
|
||||
let messageTypeTracker =
|
||||
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
|
||||
|
||||
if (!messageTypeTracker || messageTypeTracker.windowStart < windowStart) {
|
||||
if (
|
||||
!messageTypeTracker ||
|
||||
messageTypeTracker.windowStart < windowStart
|
||||
) {
|
||||
// New window or first request for this message type - initialize from Redis if available
|
||||
messageTypeTracker = await this.initializeMessageTypeTracker(clientId, messageType);
|
||||
messageTypeTracker = await this.initializeMessageTypeTracker(
|
||||
clientId,
|
||||
messageType
|
||||
);
|
||||
messageTypeTracker.windowStart = currentTime;
|
||||
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
|
||||
this.localMessageTypeRateLimitTracker.set(
|
||||
messageTypeKey,
|
||||
messageTypeTracker
|
||||
);
|
||||
}
|
||||
|
||||
// Increment message type counters
|
||||
messageTypeTracker.count++;
|
||||
messageTypeTracker.pendingCount++;
|
||||
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
|
||||
this.localMessageTypeRateLimitTracker.set(
|
||||
messageTypeKey,
|
||||
messageTypeTracker
|
||||
);
|
||||
|
||||
// Check if message type limit would be exceeded
|
||||
if (messageTypeTracker.count >= messageTypeLimit) {
|
||||
@@ -359,25 +417,38 @@ export class RateLimitService {
|
||||
isLimited: true,
|
||||
reason: `message_type:${messageType}`,
|
||||
totalHits: messageTypeTracker.count,
|
||||
resetTime: new Date((messageTypeTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
|
||||
resetTime: new Date(
|
||||
(messageTypeTracker.windowStart +
|
||||
Math.floor(windowMs / 1000)) *
|
||||
1000
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
// Sync to Redis if threshold reached
|
||||
if (messageTypeTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
|
||||
this.syncMessageTypeRateLimitToRedis(clientId, messageType, messageTypeTracker);
|
||||
this.syncMessageTypeRateLimitToRedis(
|
||||
clientId,
|
||||
messageType,
|
||||
messageTypeTracker
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isLimited: false,
|
||||
totalHits: globalTracker.count,
|
||||
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
|
||||
resetTime: new Date(
|
||||
(globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
// Decrement function for skipSuccessfulRequests/skipFailedRequests functionality
|
||||
async decrementRateLimit(clientId: string, messageType?: string): Promise<void> {
|
||||
async decrementRateLimit(
|
||||
clientId: string,
|
||||
messageType?: string
|
||||
): Promise<void> {
|
||||
// Decrement global counter
|
||||
const globalTracker = this.localRateLimitTracker.get(clientId);
|
||||
if (globalTracker && globalTracker.count > 0) {
|
||||
@@ -389,7 +460,8 @@ export class RateLimitService {
|
||||
// Decrement message type counter if provided
|
||||
if (messageType) {
|
||||
const messageTypeKey = `${clientId}:${messageType}`;
|
||||
const messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
|
||||
const messageTypeTracker =
|
||||
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
|
||||
if (messageTypeTracker && messageTypeTracker.count > 0) {
|
||||
messageTypeTracker.count--;
|
||||
messageTypeTracker.pendingCount--;
|
||||
@@ -401,7 +473,7 @@ export class RateLimitService {
|
||||
async resetKey(clientId: string): Promise<void> {
|
||||
// Remove from local tracking
|
||||
this.localRateLimitTracker.delete(clientId);
|
||||
|
||||
|
||||
// Remove all message type entries for this client
|
||||
for (const [key] of this.localMessageTypeRateLimitTracker) {
|
||||
if (key.startsWith(`${clientId}:`)) {
|
||||
@@ -417,9 +489,13 @@ export class RateLimitService {
|
||||
// Get all message type keys for this client and delete them
|
||||
const client = redisManager.getClient();
|
||||
if (client) {
|
||||
const messageTypeKeys = await client.keys(`ratelimit:${clientId}:*`);
|
||||
const messageTypeKeys = await client.keys(
|
||||
`ratelimit:${clientId}:*`
|
||||
);
|
||||
if (messageTypeKeys.length > 0) {
|
||||
await Promise.all(messageTypeKeys.map(key => redisManager.del(key)));
|
||||
await Promise.all(
|
||||
messageTypeKeys.map((key) => redisManager.del(key))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -431,7 +507,10 @@ export class RateLimitService {
|
||||
const windowStart = currentTime - RATE_LIMIT_WINDOW;
|
||||
|
||||
// Clean up global rate limit tracking and sync pending data
|
||||
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
|
||||
for (const [
|
||||
clientId,
|
||||
tracker
|
||||
] of this.localRateLimitTracker.entries()) {
|
||||
if (tracker.windowStart < windowStart) {
|
||||
// Sync any pending data before cleanup
|
||||
if (tracker.pendingCount > 0) {
|
||||
@@ -442,12 +521,19 @@ export class RateLimitService {
|
||||
}
|
||||
|
||||
// Clean up message type rate limit tracking and sync pending data
|
||||
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
|
||||
for (const [
|
||||
key,
|
||||
tracker
|
||||
] of this.localMessageTypeRateLimitTracker.entries()) {
|
||||
if (tracker.windowStart < windowStart) {
|
||||
// Sync any pending data before cleanup
|
||||
if (tracker.pendingCount > 0) {
|
||||
const [clientId, messageType] = key.split(":", 2);
|
||||
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
|
||||
await this.syncMessageTypeRateLimitToRedis(
|
||||
clientId,
|
||||
messageType,
|
||||
tracker
|
||||
);
|
||||
}
|
||||
this.localMessageTypeRateLimitTracker.delete(key);
|
||||
}
|
||||
@@ -461,17 +547,27 @@ export class RateLimitService {
|
||||
logger.debug("Force syncing all pending rate limit data to Redis...");
|
||||
|
||||
// Sync all pending global rate limits
|
||||
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
|
||||
for (const [
|
||||
clientId,
|
||||
tracker
|
||||
] of this.localRateLimitTracker.entries()) {
|
||||
if (tracker.pendingCount > 0) {
|
||||
await this.syncRateLimitToRedis(clientId, tracker);
|
||||
}
|
||||
}
|
||||
|
||||
// Sync all pending message type rate limits
|
||||
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
|
||||
for (const [
|
||||
key,
|
||||
tracker
|
||||
] of this.localMessageTypeRateLimitTracker.entries()) {
|
||||
if (tracker.pendingCount > 0) {
|
||||
const [clientId, messageType] = key.split(":", 2);
|
||||
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
|
||||
await this.syncMessageTypeRateLimitToRedis(
|
||||
clientId,
|
||||
messageType,
|
||||
tracker
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -504,4 +600,4 @@ export class RateLimitService {
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const rateLimitService = new RateLimitService();
|
||||
export const rateLimitService = new RateLimitService();
|
||||
|
||||
@@ -17,7 +17,10 @@ import { MemoryStore, Store } from "express-rate-limit";
|
||||
import RedisStore from "#private/lib/redisStore";
|
||||
|
||||
export function createStore(): Store {
|
||||
if (build != "oss" && privateConfig.getRawPrivateConfig().flags.enable_redis) {
|
||||
if (
|
||||
build != "oss" &&
|
||||
privateConfig.getRawPrivateConfig().flags.enable_redis
|
||||
) {
|
||||
const rateLimitStore: Store = new RedisStore({
|
||||
prefix: "api-rate-limit", // Optional: customize Redis key prefix
|
||||
skipFailedRequests: true, // Don't count failed requests
|
||||
|
||||
@@ -115,7 +115,6 @@ export const privateConfigSchema = z.object({
|
||||
.optional()
|
||||
})
|
||||
.optional(),
|
||||
favicon_path: z.string().optional(),
|
||||
footer: z
|
||||
.array(
|
||||
z.object({
|
||||
@@ -127,14 +126,12 @@ export const privateConfigSchema = z.object({
|
||||
hide_auth_layout_footer: z.boolean().optional().default(false),
|
||||
login_page: z
|
||||
.object({
|
||||
subtitle_text: z.string().optional(),
|
||||
title_text: z.string().optional()
|
||||
subtitle_text: z.string().optional()
|
||||
})
|
||||
.optional(),
|
||||
signup_page: z
|
||||
.object({
|
||||
subtitle_text: z.string().optional(),
|
||||
title_text: z.string().optional()
|
||||
subtitle_text: z.string().optional()
|
||||
})
|
||||
.optional(),
|
||||
resource_auth_page: z
|
||||
|
||||
@@ -19,7 +19,7 @@ import { build } from "@server/build";
|
||||
class RedisManager {
|
||||
public client: Redis | null = null;
|
||||
private writeClient: Redis | null = null; // Master for writes
|
||||
private readClient: Redis | null = null; // Replica for reads
|
||||
private readClient: Redis | null = null; // Replica for reads
|
||||
private subscriber: Redis | null = null;
|
||||
private publisher: Redis | null = null;
|
||||
private isEnabled: boolean = false;
|
||||
@@ -46,7 +46,8 @@ class RedisManager {
|
||||
this.isEnabled = false;
|
||||
return;
|
||||
}
|
||||
this.isEnabled = privateConfig.getRawPrivateConfig().flags.enable_redis || false;
|
||||
this.isEnabled =
|
||||
privateConfig.getRawPrivateConfig().flags.enable_redis || false;
|
||||
if (this.isEnabled) {
|
||||
this.initializeClients();
|
||||
}
|
||||
@@ -63,15 +64,19 @@ class RedisManager {
|
||||
}
|
||||
|
||||
private async triggerReconnectionCallbacks(): Promise<void> {
|
||||
logger.info(`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`);
|
||||
|
||||
const promises = Array.from(this.reconnectionCallbacks).map(async (callback) => {
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
logger.error("Error in reconnection callback:", error);
|
||||
logger.info(
|
||||
`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`
|
||||
);
|
||||
|
||||
const promises = Array.from(this.reconnectionCallbacks).map(
|
||||
async (callback) => {
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
logger.error("Error in reconnection callback:", error);
|
||||
}
|
||||
}
|
||||
});
|
||||
);
|
||||
|
||||
await Promise.allSettled(promises);
|
||||
}
|
||||
@@ -79,13 +84,17 @@ class RedisManager {
|
||||
private async resubscribeToChannels(): Promise<void> {
|
||||
if (!this.subscriber || this.subscribers.size === 0) return;
|
||||
|
||||
logger.info(`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`);
|
||||
|
||||
logger.info(
|
||||
`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`
|
||||
);
|
||||
|
||||
try {
|
||||
const channels = Array.from(this.subscribers.keys());
|
||||
if (channels.length > 0) {
|
||||
await this.subscriber.subscribe(...channels);
|
||||
logger.info(`Successfully re-subscribed to channels: ${channels.join(', ')}`);
|
||||
logger.info(
|
||||
`Successfully re-subscribed to channels: ${channels.join(", ")}`
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Failed to re-subscribe to channels:", error);
|
||||
@@ -98,7 +107,7 @@ class RedisManager {
|
||||
host: redisConfig.host!,
|
||||
port: redisConfig.port!,
|
||||
password: redisConfig.password,
|
||||
db: redisConfig.db,
|
||||
db: redisConfig.db
|
||||
// tls: {
|
||||
// rejectUnauthorized:
|
||||
// redisConfig.tls?.reject_unauthorized || false
|
||||
@@ -112,7 +121,7 @@ class RedisManager {
|
||||
if (!redisConfig.replicas || redisConfig.replicas.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
// Use the first replica for simplicity
|
||||
// In production, you might want to implement load balancing across replicas
|
||||
const replica = redisConfig.replicas[0];
|
||||
@@ -120,7 +129,7 @@ class RedisManager {
|
||||
host: replica.host!,
|
||||
port: replica.port!,
|
||||
password: replica.password,
|
||||
db: replica.db || redisConfig.db,
|
||||
db: replica.db || redisConfig.db
|
||||
// tls: {
|
||||
// rejectUnauthorized:
|
||||
// replica.tls?.reject_unauthorized || false
|
||||
@@ -133,7 +142,7 @@ class RedisManager {
|
||||
private initializeClients(): void {
|
||||
const masterConfig = this.getRedisConfig();
|
||||
const replicaConfig = this.getReplicaRedisConfig();
|
||||
|
||||
|
||||
this.hasReplicas = replicaConfig !== null;
|
||||
|
||||
try {
|
||||
@@ -144,7 +153,7 @@ class RedisManager {
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
commandTimeout: this.commandTimeout
|
||||
});
|
||||
|
||||
// Initialize replica connection for reads (if available)
|
||||
@@ -155,7 +164,7 @@ class RedisManager {
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
commandTimeout: this.commandTimeout
|
||||
});
|
||||
} else {
|
||||
// Fallback to master for reads if no replicas
|
||||
@@ -172,7 +181,7 @@ class RedisManager {
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
commandTimeout: this.commandTimeout
|
||||
});
|
||||
|
||||
// Subscriber uses replica if available (reads)
|
||||
@@ -182,7 +191,7 @@ class RedisManager {
|
||||
maxRetriesPerRequest: 3,
|
||||
keepAlive: 30000,
|
||||
connectTimeout: this.connectionTimeout,
|
||||
commandTimeout: this.commandTimeout,
|
||||
commandTimeout: this.commandTimeout
|
||||
});
|
||||
|
||||
// Add reconnection handlers for write client
|
||||
@@ -202,11 +211,14 @@ class RedisManager {
logger.info("Redis write client ready");
this.isWriteHealthy = true;
this.updateOverallHealth();

// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -233,11 +245,14 @@ class RedisManager {
logger.info("Redis read client ready");
this.isReadHealthy = true;
this.updateOverallHealth();

// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -298,8 +313,8 @@ class RedisManager {
}
);

const setupMessage = this.hasReplicas
? "Redis clients initialized successfully with replica support"
const setupMessage = this.hasReplicas
? "Redis clients initialized successfully with replica support"
: "Redis clients initialized successfully (single instance)";
logger.info(setupMessage);

@@ -313,7 +328,8 @@ class RedisManager {

private updateOverallHealth(): void {
// Overall health is true if write is healthy and (read is healthy OR we don't have replicas)
this.isHealthy = this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
this.isHealthy =
this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
}

private async executeWithRetry<T>(
@@ -322,49 +338,61 @@ class RedisManager {
fallbackOperation?: () => Promise<T>
): Promise<T> {
let lastError: Error | null = null;

for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
try {
return await operation();
} catch (error) {
lastError = error as Error;

// If this is the last attempt, try fallback if available
if (attempt === this.maxRetries && fallbackOperation) {
try {
logger.warn(`${operationName} primary operation failed, trying fallback`);
logger.warn(
`${operationName} primary operation failed, trying fallback`
);
return await fallbackOperation();
} catch (fallbackError) {
logger.error(`${operationName} fallback also failed:`, fallbackError);
logger.error(
`${operationName} fallback also failed:`,
fallbackError
);
throw lastError;
}
}

// Don't retry on the last attempt
if (attempt === this.maxRetries) {
break;
}

// Calculate delay with exponential backoff
const delay = Math.min(
this.baseRetryDelay * Math.pow(this.backoffMultiplier, attempt),
this.baseRetryDelay *
Math.pow(this.backoffMultiplier, attempt),
this.maxRetryDelay
);

logger.warn(`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`, error);

logger.warn(
`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`,
error
);

// Wait before retrying
await new Promise(resolve => setTimeout(resolve, delay));
await new Promise((resolve) => setTimeout(resolve, delay));
}
}

logger.error(`${operationName} failed after ${this.maxRetries + 1} attempts:`, lastError);

logger.error(
`${operationName} failed after ${this.maxRetries + 1} attempts:`,
lastError
);
throw lastError;
}

private startHealthMonitoring(): void {
if (!this.isEnabled) return;

// Check health every 30 seconds
setInterval(async () => {
try {
@@ -381,7 +409,7 @@ class RedisManager {

private async checkRedisHealth(): Promise<boolean> {
const now = Date.now();

// Only check health every 30 seconds
if (now - this.lastHealthCheck < this.healthCheckInterval) {
return this.isHealthy;
@@ -400,24 +428,45 @@ class RedisManager {
// Check write client (master) health
await Promise.race([
this.writeClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Write client health check timeout')), 2000)
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error("Write client health check timeout")
),
2000
)
)
]);
this.isWriteHealthy = true;

// Check read client health if it's different from write client
if (this.hasReplicas && this.readClient && this.readClient !== this.writeClient) {
if (
this.hasReplicas &&
this.readClient &&
this.readClient !== this.writeClient
) {
try {
await Promise.race([
this.readClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Read client health check timeout')), 2000)
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error(
"Read client health check timeout"
)
),
2000
)
)
]);
this.isReadHealthy = true;
} catch (error) {
logger.error("Redis read client health check failed:", error);
logger.error(
"Redis read client health check failed:",
error
);
this.isReadHealthy = false;
}
} else {
@@ -475,16 +524,13 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.writeClient) return false;

try {
await this.executeWithRetry(
async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
},
"Redis SET"
);
await this.executeWithRetry(async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
}, "Redis SET");
return true;
} catch (error) {
logger.error("Redis SET error:", error);
@@ -496,9 +542,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;

try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.get(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.get(key)
: undefined;

return await this.executeWithRetry(
() => this.readClient!.get(key),
@@ -560,9 +607,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return [];

try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.smembers(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.smembers(key)
: undefined;

return await this.executeWithRetry(
() => this.readClient!.smembers(key),
@@ -598,9 +646,9 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;

try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.hget(key, field)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hget(key, field)
: undefined;

return await this.executeWithRetry(
() => this.readClient!.hget(key, field),
@@ -632,9 +681,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return {};

try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.hgetall(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hgetall(key)
: undefined;

return await this.executeWithRetry(
() => this.readClient!.hgetall(key),
@@ -658,18 +708,18 @@ class RedisManager {
}

try {
await this.executeWithRetry(
async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
await this.executeWithRetry(async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(
() => reject(new Error("Redis publish timeout")),
3000
)
]);
},
"Redis PUBLISH"
);
)
]);
}, "Redis PUBLISH");
return true;
} catch (error) {
logger.error("Redis PUBLISH error:", error);
@@ -689,17 +739,20 @@ class RedisManager {
if (!this.subscribers.has(channel)) {
this.subscribers.set(channel, new Set());
// Only subscribe to the channel if it's the first subscriber
await this.executeWithRetry(
async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
await this.executeWithRetry(async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error("Redis subscribe timeout")
),
5000
)
]);
},
"Redis SUBSCRIBE"
);
)
]);
}, "Redis SUBSCRIBE");
}

this.subscribers.get(channel)!.add(callback);
@@ -11,9 +11,9 @@
* This file is not licensed under the AGPLv3.
*/

import { Store, Options, IncrementResponse } from 'express-rate-limit';
import { rateLimitService } from './rateLimit';
import logger from '@server/logger';
import { Store, Options, IncrementResponse } from "express-rate-limit";
import { rateLimitService } from "./rateLimit";
import logger from "@server/logger";

/**
* A Redis-backed rate limiting store for express-rate-limit that optimizes
@@ -57,12 +57,14 @@ export default class RedisStore implements Store {
*
* @param options - Configuration options for the store.
*/
constructor(options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}) {
this.prefix = options.prefix || 'express-rate-limit';
constructor(
options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}
) {
this.prefix = options.prefix || "express-rate-limit";
this.skipFailedRequests = options.skipFailedRequests || false;
this.skipSuccessfulRequests = options.skipSuccessfulRequests || false;
}
@@ -101,7 +103,8 @@ export default class RedisStore implements Store {

return {
totalHits: result.totalHits || 1,
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore increment error for key ${key}:`, error);
@@ -158,7 +161,9 @@ export default class RedisStore implements Store {
*/
async resetAll(): Promise<void> {
try {
logger.warn('RedisStore resetAll called - this operation can be expensive');
logger.warn(
"RedisStore resetAll called - this operation can be expensive"
);

// Force sync all pending data first
await rateLimitService.forceSyncAllPendingData();
@@ -167,9 +172,9 @@ export default class RedisStore implements Store {
// scanning all Redis keys with our prefix, which could be expensive.
// In production, it's better to let entries expire naturally.

logger.info('RedisStore resetAll completed (pending data synced)');
logger.info("RedisStore resetAll completed (pending data synced)");
} catch (error) {
logger.error('RedisStore resetAll error:', error);
logger.error("RedisStore resetAll error:", error);
// Don't throw - this is an optional method
}
}
@@ -181,7 +186,9 @@ export default class RedisStore implements Store {
* @param key - The identifier for a client.
* @returns Current hit count and reset time, or null if no data exists.
*/
async getHits(key: string): Promise<{ totalHits: number; resetTime: Date } | null> {
async getHits(
key: string
): Promise<{ totalHits: number; resetTime: Date } | null> {
try {
const clientId = `${this.prefix}:${key}`;

@@ -200,7 +207,8 @@ export default class RedisStore implements Store {

return {
totalHits: Math.max(0, (result.totalHits || 0) - 1), // Adjust for the decrement
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore getHits error for key ${key}:`, error);
@@ -215,9 +223,9 @@ export default class RedisStore implements Store {
async shutdown(): Promise<void> {
try {
// The rateLimitService handles its own cleanup
logger.info('RedisStore shutdown completed');
logger.info("RedisStore shutdown completed");
} catch (error) {
logger.error('RedisStore shutdown error:', error);
logger.error("RedisStore shutdown error:", error);
}
}
}

@@ -16,10 +16,10 @@ import privateConfig from "#private/lib/config";
import logger from "@server/logger";

export enum AudienceIds {
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
}

const resend = new Resend(
@@ -33,7 +33,9 @@ export async function moveEmailToAudience(
audienceId: AudienceIds
) {
if (process.env.ENVIRONMENT !== "prod") {
logger.debug(`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`);
logger.debug(
`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`
);
return;
}
const { error, data } = await retryWithBackoff(async () => {

@@ -189,7 +189,7 @@ export async function getTraefikConfig(
);

if (!validation.isValid) {
logger.error(
logger.debug(
`Invalid path rewrite configuration for resource ${resourceId}: ${validation.error}`
);
return;
@@ -823,7 +823,7 @@ export async function getTraefikConfig(
(cert) => cert.queriedDomain === lp.fullDomain
);
if (!matchingCert) {
logger.warn(
logger.debug(
`No matching certificate found for login page domain: ${lp.fullDomain}`
);
continue;

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/

export * from "./getTraefikConfig";
export * from "./getTraefikConfig";

@@ -64,11 +64,14 @@ export class License {
private validationServerUrl = `${this.serverBaseUrl}/api/v1/license/enterprise/validate`;
private activationServerUrl = `${this.serverBaseUrl}/api/v1/license/enterprise/activate`;

private statusCache = new NodeCache({ stdTTL: this.phoneHomeInterval });
private statusCache = new NodeCache();
private licenseKeyCache = new NodeCache();

private statusKey = "status";
private serverSecret!: string;
private phoneHomeFailureCount = 0;
private checkInProgress = false;
private doRecheck = false;

private publicKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAx9RKc8cw+G8r7h/xeozF
@@ -81,12 +84,11 @@ LQIDAQAB
-----END PUBLIC KEY-----`;

constructor(private hostMeta: HostMeta) {
setInterval(
async () => {
await this.check();
},
1000 * 60 * 60
);
setInterval(async () => {
this.doRecheck = true;
await this.check();
this.doRecheck = false;
}, 1000 * this.phoneHomeInterval);
}

public listKeys(): LicenseKeyCache[] {
@@ -103,6 +105,7 @@ LQIDAQAB
public async forceRecheck() {
this.statusCache.flushAll();
this.licenseKeyCache.flushAll();
this.phoneHomeFailureCount = 0;

return await this.check();
}
@@ -118,24 +121,49 @@ LQIDAQAB
}

public async check(): Promise<LicenseStatus> {
// If a check is already in progress, return the last known status
if (this.checkInProgress) {
logger.debug(
"License check already in progress, returning last known status"
);
const lastStatus = this.statusCache.get(this.statusKey) as
| LicenseStatus
| undefined;
if (lastStatus) {
return lastStatus;
}
// If no cached status exists, return default status
return {
hostId: this.hostMeta.hostMetaId,
isHostLicensed: true,
isLicenseValid: false
};
}

const status: LicenseStatus = {
hostId: this.hostMeta.hostMetaId,
isHostLicensed: true,
isLicenseValid: false
};

this.checkInProgress = true;

try {
if (this.statusCache.has(this.statusKey)) {
if (!this.doRecheck && this.statusCache.has(this.statusKey)) {
const res = this.statusCache.get("status") as LicenseStatus;
return res;
}
// Invalidate all
this.licenseKeyCache.flushAll();
logger.debug("Checking license status...");
// Build new cache in temporary Map before invalidating old cache
const newCache = new Map<string, LicenseKeyCache>();

const allKeysRes = await db.select().from(licenseKey);

if (allKeysRes.length === 0) {
status.isHostLicensed = false;
// Invalidate all and set new cache (empty)
this.licenseKeyCache.flushAll();
this.statusCache.set(this.statusKey, status);
return status;
}

@@ -158,7 +186,7 @@ LQIDAQAB
|
||||
this.publicKey
|
||||
);
|
||||
|
||||
this.licenseKeyCache.set<LicenseKeyCache>(decryptedKey, {
|
||||
newCache.set(decryptedKey, {
|
||||
licenseKey: decryptedKey,
|
||||
licenseKeyEncrypted: key.licenseKeyId,
|
||||
valid: payload.valid,
|
||||
@@ -177,14 +205,11 @@ LQIDAQAB
|
||||
);
|
||||
logger.error(e);
|
||||
|
||||
this.licenseKeyCache.set<LicenseKeyCache>(
|
||||
key.licenseKeyId,
|
||||
{
|
||||
licenseKey: key.licenseKeyId,
|
||||
licenseKeyEncrypted: key.licenseKeyId,
|
||||
valid: false
|
||||
}
|
||||
);
|
||||
newCache.set(key.licenseKeyId, {
|
||||
licenseKey: key.licenseKeyId,
|
||||
licenseKeyEncrypted: key.licenseKeyId,
|
||||
valid: false
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,17 +231,31 @@ LQIDAQAB
|
||||
if (!apiResponse?.success) {
|
||||
throw new Error(apiResponse?.error);
|
||||
}
|
||||
// Reset failure count on success
|
||||
this.phoneHomeFailureCount = 0;
|
||||
} catch (e) {
|
||||
logger.error("Error communicating with license server:");
|
||||
logger.error(e);
|
||||
this.phoneHomeFailureCount++;
|
||||
if (this.phoneHomeFailureCount === 1) {
|
||||
// First failure: fail silently
|
||||
logger.error("Error communicating with license server:");
|
||||
logger.error(e);
|
||||
logger.error(
|
||||
`Allowing failure. Will retry one more time at next run interval.`
|
||||
);
|
||||
// return last known good status
|
||||
return this.statusCache.get(
|
||||
this.statusKey
|
||||
) as LicenseStatus;
|
||||
} else {
|
||||
// Subsequent failures: fail abruptly
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
// Check and update all license keys with server response
|
||||
for (const key of keys) {
|
||||
try {
|
||||
const cached = this.licenseKeyCache.get<LicenseKeyCache>(
|
||||
key.licenseKey
|
||||
)!;
|
||||
const cached = newCache.get(key.licenseKey)!;
|
||||
const licenseKeyRes =
|
||||
apiResponse?.data?.licenseKeys[key.licenseKey];
|
||||
|
||||
@@ -240,10 +279,7 @@ LQIDAQAB
|
||||
`Can't trust license key: ${key.licenseKey}`
|
||||
);
|
||||
cached.valid = false;
|
||||
this.licenseKeyCache.set<LicenseKeyCache>(
|
||||
key.licenseKey,
|
||||
cached
|
||||
);
|
||||
newCache.set(key.licenseKey, cached);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -274,10 +310,7 @@ LQIDAQAB
|
||||
})
|
||||
.where(eq(licenseKey.licenseKeyId, encryptedKey));
|
||||
|
||||
this.licenseKeyCache.set<LicenseKeyCache>(
|
||||
key.licenseKey,
|
||||
cached
|
||||
);
|
||||
newCache.set(key.licenseKey, cached);
|
||||
} catch (e) {
|
||||
logger.error(`Error validating license key: ${key}`);
|
||||
logger.error(e);
|
||||
@@ -286,9 +319,7 @@ LQIDAQAB
|
||||
|
||||
// Compute host status
|
||||
for (const key of keys) {
|
||||
const cached = this.licenseKeyCache.get<LicenseKeyCache>(
|
||||
key.licenseKey
|
||||
)!;
|
||||
const cached = newCache.get(key.licenseKey)!;
|
||||
|
||||
if (cached.type === "host") {
|
||||
status.isLicenseValid = cached.valid;
|
||||
@@ -299,9 +330,17 @@ LQIDAQAB
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate old cache and set new cache
|
||||
this.licenseKeyCache.flushAll();
|
||||
for (const [key, value] of newCache.entries()) {
|
||||
this.licenseKeyCache.set<LicenseKeyCache>(key, value);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error checking license status:");
|
||||
logger.error(error);
|
||||
} finally {
|
||||
this.checkInProgress = false;
|
||||
}
|
||||
|
||||
this.statusCache.set(this.statusKey, status);
|
||||
@@ -430,20 +469,58 @@ LQIDAQAB
|
||||
: key.instanceId
|
||||
}));
|
||||
|
||||
const response = await fetch(this.validationServerUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
body: JSON.stringify({
|
||||
licenseKeys: decryptedKeys,
|
||||
instanceName: this.hostMeta.hostMetaId
|
||||
})
|
||||
});
|
||||
const maxAttempts = 10;
|
||||
const initialRetryDelay = 1 * 1000; // 1 seconds
|
||||
const exponentialFactor = 1.2;
|
||||
|
||||
const data = await response.json();
|
||||
let lastError: Error | undefined;
|
||||
|
||||
return data as ValidateLicenseAPIResponse;
|
||||
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||
try {
|
||||
const response = await fetch(this.validationServerUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
body: JSON.stringify({
|
||||
licenseKeys: decryptedKeys,
|
||||
instanceName: this.hostMeta.hostMetaId
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data as ValidateLicenseAPIResponse;
|
||||
} catch (error) {
|
||||
lastError =
|
||||
error instanceof Error ? error : new Error(String(error));
|
||||
|
||||
if (attempt < maxAttempts) {
|
||||
// Calculate exponential backoff delay
|
||||
const retryDelay = Math.floor(
|
||||
initialRetryDelay *
|
||||
Math.pow(exponentialFactor, attempt - 1)
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`License validation request failed (attempt ${attempt}/${maxAttempts}), retrying in ${retryDelay} ms...`
|
||||
);
|
||||
await new Promise((resolve) =>
|
||||
setTimeout(resolve, retryDelay)
|
||||
);
|
||||
} else {
|
||||
logger.error(
|
||||
`License validation request failed after ${maxAttempts} attempts`
|
||||
);
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError || new Error("License validation request failed");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,10 +19,7 @@ import * as crypto from "crypto";
|
||||
* @param publicKey - The public key used for verification (PEM format)
|
||||
* @returns The decoded payload if validation succeeds, throws an error otherwise
|
||||
*/
|
||||
function validateJWT<Payload>(
|
||||
token: string,
|
||||
publicKey: string
|
||||
): Payload {
|
||||
function validateJWT<Payload>(token: string, publicKey: string): Payload {
|
||||
// Split the JWT into its three parts
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) {
|
||||
|
||||
@@ -19,6 +19,7 @@ import { Request, Response, NextFunction } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { and, eq, lt } from "drizzle-orm";
|
||||
import cache from "@server/lib/cache";
|
||||
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
|
||||
|
||||
async function getActionDays(orgId: string): Promise<number> {
|
||||
// check cache first
|
||||
@@ -40,15 +41,17 @@ async function getActionDays(orgId: string): Promise<number> {
|
||||
}
|
||||
|
||||
// store the result in cache
|
||||
cache.set(`org_${orgId}_actionDays`, org.settingsLogRetentionDaysAction, 300);
|
||||
cache.set(
|
||||
`org_${orgId}_actionDays`,
|
||||
org.settingsLogRetentionDaysAction,
|
||||
300
|
||||
);
|
||||
|
||||
return org.settingsLogRetentionDaysAction;
|
||||
}
|
||||
|
||||
export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
|
||||
const cutoffTimestamp = now - retentionDays * 24 * 60 * 60;
|
||||
const cutoffTimestamp = calculateCutoffTimestamp(retentionDays);
|
||||
|
||||
try {
|
||||
await db
|
||||
@@ -142,4 +145,3 @@ export function logActionAudit(action: ActionsEnum) {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ export async function verifyCertificateAccess(
|
||||
try {
|
||||
// Assume user/org access is already verified
|
||||
const orgId = req.params.orgId;
|
||||
const certId = req.params.certId || req.body?.certId || req.query?.certId;
|
||||
const certId =
|
||||
req.params.certId || req.body?.certId || req.query?.certId;
|
||||
let domainId =
|
||||
req.params.domainId || req.body?.domainId || req.query?.domainId;
|
||||
|
||||
@@ -39,10 +40,12 @@ export async function verifyCertificateAccess(
|
||||
}
|
||||
|
||||
if (!domainId) {
|
||||
|
||||
if (!certId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Must provide either certId or domainId"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -75,7 +78,10 @@ export async function verifyCertificateAccess(
|
||||
|
||||
if (!domainId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Must provide either certId or domainId"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,7 @@ export async function verifyIdpAccess(
|
||||
) {
|
||||
try {
|
||||
const userId = req.user!.userId;
|
||||
const idpId =
|
||||
req.params.idpId || req.body.idpId || req.query.idpId;
|
||||
const idpId = req.params.idpId || req.body.idpId || req.query.idpId;
|
||||
const orgId = req.params.orgId;
|
||||
|
||||
if (!userId) {
|
||||
@@ -50,9 +49,7 @@ export async function verifyIdpAccess(
|
||||
.select()
|
||||
.from(idp)
|
||||
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
|
||||
.where(
|
||||
and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId))
|
||||
)
|
||||
.where(and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {
|
||||
|
||||
@@ -26,7 +26,8 @@ export const verifySessionRemoteExitNodeMiddleware = async (
|
||||
// get the token from the auth header
|
||||
const token = req.headers["authorization"]?.split(" ")[1] || "";
|
||||
|
||||
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
|
||||
const { session, remoteExitNode } =
|
||||
await validateRemoteExitNodeSessionToken(token);
|
||||
|
||||
if (!session || !remoteExitNode) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
|
||||
@@ -19,8 +19,14 @@ import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
|
||||
import {
|
||||
queryAccessAuditLogsParams,
|
||||
queryAccessAuditLogsQuery,
|
||||
queryAccess,
|
||||
countAccessQuery
|
||||
} from "./queryAccessAuditLog";
|
||||
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
|
||||
import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs";
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
@@ -61,16 +67,28 @@ export async function exportAccessAuditLogs(
|
||||
}
|
||||
|
||||
const data = { ...parsedQuery.data, ...parsedParams.data };
|
||||
const [{ count }] = await countAccessQuery(data);
|
||||
if (count > MAX_EXPORT_LIMIT) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const baseQuery = queryAccess(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
const csvData = generateCSV(log);
|
||||
|
||||
res.setHeader('Content-Type', 'text/csv');
|
||||
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
|
||||
|
||||
|
||||
res.setHeader("Content-Type", "text/csv");
|
||||
res.setHeader(
|
||||
"Content-Disposition",
|
||||
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
|
||||
);
|
||||
|
||||
return res.send(csvData);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
@@ -78,4 +96,4 @@ export async function exportAccessAuditLogs(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,14 @@ import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
|
||||
import {
|
||||
queryActionAuditLogsParams,
|
||||
queryActionAuditLogsQuery,
|
||||
queryAction,
|
||||
countActionQuery
|
||||
} from "./queryActionAuditLog";
|
||||
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
|
||||
import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs";
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
@@ -60,17 +66,29 @@ export async function exportActionAuditLogs(
|
||||
);
|
||||
}
|
||||
|
||||
const data = { ...parsedQuery.data, ...parsedParams.data };
|
||||
const data = { ...parsedQuery.data, ...parsedParams.data };
|
||||
const [{ count }] = await countActionQuery(data);
|
||||
if (count > MAX_EXPORT_LIMIT) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const baseQuery = queryAction(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
const csvData = generateCSV(log);
|
||||
|
||||
res.setHeader('Content-Type', 'text/csv');
|
||||
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
|
||||
|
||||
|
||||
res.setHeader("Content-Type", "text/csv");
|
||||
res.setHeader(
|
||||
"Content-Disposition",
|
||||
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
|
||||
);
|
||||
|
||||
return res.send(csvData);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
@@ -78,4 +96,4 @@ export async function exportActionAuditLogs(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,4 +14,4 @@
|
||||
export * from "./queryActionAuditLog";
|
||||
export * from "./exportActionAuditLog";
|
||||
export * from "./queryAccessAuditLog";
|
||||
export * from "./exportAccessAuditLog";
|
||||
export * from "./exportAccessAuditLog";
|
||||
|
||||
@@ -24,6 +24,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { QueryAccessAuditLogResponse } from "@server/routers/auditLogs/types";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
|
||||
|
||||
export const queryAccessAuditLogsQuery = z.object({
|
||||
// iso string just validate its a parseable date
|
||||
@@ -32,7 +33,14 @@ export const queryAccessAuditLogsQuery = z.object({
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
error: "timeStart must be a valid ISO date string"
|
||||
})
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000)),
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
|
||||
.prefault(() => getSevenDaysAgo().toISOString())
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description:
|
||||
"Start time as ISO date string (defaults to 7 days ago)"
|
||||
}),
|
||||
timeEnd: z
|
||||
.string()
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
@@ -44,7 +52,8 @@ export const queryAccessAuditLogsQuery = z.object({
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description: "End time as ISO date string (defaults to current time)"
|
||||
description:
|
||||
"End time as ISO date string (defaults to current time)"
|
||||
}),
|
||||
action: z
|
||||
.union([z.boolean(), z.string()])
|
||||
@@ -181,9 +190,15 @@ async function queryUniqueFilterAttributes(
|
||||
.where(baseConditions);
|
||||
|
||||
return {
|
||||
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter((row): row is { id: number; name: string | null } => row.id !== null),
|
||||
locations: uniqueLocations.map(row => row.locations).filter((location): location is string => location !== null)
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter(
|
||||
(row): row is { id: number; name: string | null } => row.id !== null
|
||||
),
|
||||
locations: uniqueLocations
|
||||
.map((row) => row.locations)
|
||||
.filter((location): location is string => location !== null)
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { QueryActionAuditLogResponse } from "@server/routers/auditLogs/types";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
|
||||
|
||||
export const queryActionAuditLogsQuery = z.object({
|
||||
// iso string just validate its a parseable date
|
||||
@@ -32,7 +33,14 @@ export const queryActionAuditLogsQuery = z.object({
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
error: "timeStart must be a valid ISO date string"
|
||||
})
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000)),
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
|
||||
.prefault(() => getSevenDaysAgo().toISOString())
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description:
|
||||
"Start time as ISO date string (defaults to 7 days ago)"
|
||||
}),
|
||||
timeEnd: z
|
||||
.string()
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
@@ -44,7 +52,8 @@ export const queryActionAuditLogsQuery = z.object({
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description: "End time as ISO date string (defaults to current time)"
|
||||
description:
|
||||
"End time as ISO date string (defaults to current time)"
|
||||
}),
|
||||
action: z.string().optional(),
|
||||
actorType: z.string().optional(),
|
||||
@@ -68,8 +77,9 @@ export const queryActionAuditLogsParams = z.object({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export const queryActionAuditLogsCombined =
|
||||
queryActionAuditLogsQuery.merge(queryActionAuditLogsParams);
|
||||
export const queryActionAuditLogsCombined = queryActionAuditLogsQuery.merge(
|
||||
queryActionAuditLogsParams
|
||||
);
|
||||
type Q = z.infer<typeof queryActionAuditLogsCombined>;
|
||||
|
||||
function getWhere(data: Q) {
|
||||
@@ -78,7 +88,9 @@ function getWhere(data: Q) {
|
||||
lt(actionAuditLog.timestamp, data.timeEnd),
|
||||
eq(actionAuditLog.orgId, data.orgId),
|
||||
data.actor ? eq(actionAuditLog.actor, data.actor) : undefined,
|
||||
data.actorType ? eq(actionAuditLog.actorType, data.actorType) : undefined,
|
||||
data.actorType
|
||||
? eq(actionAuditLog.actorType, data.actorType)
|
||||
: undefined,
|
||||
data.actorId ? eq(actionAuditLog.actorId, data.actorId) : undefined,
|
||||
data.action ? eq(actionAuditLog.action, data.action) : undefined
|
||||
);
|
||||
@@ -135,8 +147,12 @@ async function queryUniqueFilterAttributes(
|
||||
.where(baseConditions);
|
||||
|
||||
return {
|
||||
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
|
||||
actions: uniqueActions.map(row => row.action).filter((action): action is string => action !== null),
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
actions: uniqueActions
|
||||
.map((row) => row.action)
|
||||
.filter((action): action is string => action !== null)
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -13,4 +13,4 @@
|
||||
|
||||
export * from "./transferSession";
|
||||
export * from "./getSessionTransferToken";
|
||||
export * from "./quickStart";
|
||||
export * from "./quickStart";
|
||||
|
||||
@@ -395,7 +395,8 @@ export async function quickStart(
|
||||
.values({
|
||||
targetId: newTarget[0].targetId,
|
||||
hcEnabled: false
|
||||
}).returning();
|
||||
})
|
||||
.returning();
|
||||
|
||||
// add the new target to the targetIps array
|
||||
targetIps.push(`${ip}/32`);
|
||||
@@ -406,7 +407,12 @@ export async function quickStart(
|
||||
.where(eq(newts.siteId, siteId!))
|
||||
.limit(1);
|
||||
|
||||
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
|
||||
await addTargets(
|
||||
newt.newtId,
|
||||
newTarget,
|
||||
newHealthcheck,
|
||||
resource.protocol
|
||||
);
|
||||
|
||||
// Set resource pincode if provided
|
||||
if (pincode) {
|
||||
|
||||
@@ -26,8 +26,8 @@ import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
|
||||
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
|
||||
|
||||
const createCheckoutSessionSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function createCheckoutSession(
|
||||
req: Request,
|
||||
@@ -72,7 +72,7 @@ export async function createCheckoutSession(
|
||||
billing_address_collection: "required",
|
||||
line_items: [
|
||||
{
|
||||
price: standardTierPrice, // Use the standard tier
|
||||
price: standardTierPrice, // Use the standard tier
|
||||
quantity: 1
|
||||
},
|
||||
...getLineItems(getStandardFeaturePriceSet())
|
||||
|
||||
@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
|
||||
import stripe from "#private/lib/stripe";
|
||||
|
||||
const createPortalSessionSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function createPortalSession(
|
||||
req: Request,
|
||||
|
||||
@@ -34,8 +34,8 @@ import {
|
||||
} from "@server/db";
|
||||
|
||||
const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
|
||||
@@ -28,8 +28,8 @@ import { FeatureId } from "@server/lib/billing";
|
||||
import { GetOrgUsageResponse } from "@server/routers/billing/types";
|
||||
|
||||
const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
@@ -78,11 +78,23 @@ export async function getOrgUsage(
|
||||
// Get usage for org
|
||||
const usageData = [];
|
||||
|
||||
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
|
||||
const siteUptime = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.SITE_UPTIME
|
||||
);
|
||||
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
|
||||
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
|
||||
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
|
||||
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
|
||||
const domains = await usageService.getUsageDaily(
|
||||
orgId,
|
||||
FeatureId.DOMAINS
|
||||
);
|
||||
const remoteExitNodes = await usageService.getUsageDaily(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES
|
||||
);
|
||||
const egressData = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.EGRESS_DATA_MB
|
||||
);
|
||||
|
||||
if (siteUptime) {
|
||||
usageData.push(siteUptime);
|
||||
@@ -100,7 +112,8 @@ export async function getOrgUsage(
|
||||
usageData.push(remoteExitNodes);
|
||||
}
|
||||
|
||||
const orgLimits = await db.select()
|
||||
const orgLimits = await db
|
||||
.select()
|
||||
.from(limits)
|
||||
.where(eq(limits.orgId, orgId));
|
||||
|
||||
|
||||
@@ -31,9 +31,7 @@ export async function handleCustomerDeleted(
|
||||
return;
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(customers)
|
||||
.where(eq(customers.customerId, customer.id));
|
||||
await db.delete(customers).where(eq(customers.customerId, customer.id));
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling customer created event for ID ${customer.id}:`,
|
||||
|
||||
@@ -12,7 +12,14 @@
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
|
||||
import {
|
||||
subscriptions,
|
||||
db,
|
||||
subscriptionItems,
|
||||
customers,
|
||||
userOrgs,
|
||||
users
|
||||
} from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
@@ -43,7 +50,6 @@ export async function handleSubscriptionDeleted(
|
||||
.delete(subscriptionItems)
|
||||
.where(eq(subscriptionItems.subscriptionId, subscription.id));
|
||||
|
||||
|
||||
// Lookup customer to get orgId
|
||||
const [customer] = await db
|
||||
.select()
|
||||
@@ -58,10 +64,7 @@ export async function handleSubscriptionDeleted(
|
||||
return;
|
||||
}
|
||||
|
||||
await handleSubscriptionLifesycle(
|
||||
customer.orgId,
|
||||
subscription.status
|
||||
);
|
||||
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
|
||||
|
||||
const [orgUserRes] = await db
|
||||
.select()
|
||||
|
||||
@@ -15,4 +15,4 @@ export * from "./createCheckoutSession";
|
||||
export * from "./createPortalSession";
|
||||
export * from "./getOrgSubscription";
|
||||
export * from "./getOrgUsage";
|
||||
export * from "./internalGetOrgTier";
|
||||
export * from "./internalGetOrgTier";
|
||||
|
||||
@@ -22,8 +22,8 @@ import { getOrgTierData } from "#private/lib/billing";
|
||||
import { GetOrgTierResponse } from "@server/routers/billing/types";
|
||||
|
||||
const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function getOrgTier(
|
||||
req: Request,
|
||||
|
||||
@@ -11,11 +11,18 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
|
||||
import {
|
||||
freeLimitSet,
|
||||
limitsService,
|
||||
subscribedLimitSet
|
||||
} from "@server/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
|
||||
export async function handleSubscriptionLifesycle(
|
||||
orgId: string,
|
||||
status: string
|
||||
) {
|
||||
switch (status) {
|
||||
case "active":
|
||||
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
|
||||
@@ -42,4 +49,4 @@ export async function handleSubscriptionLifesycle(orgId: string, status: string)
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,12 +32,13 @@ export async function billingWebhookHandler(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
let event: Stripe.Event = req.body;
|
||||
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
|
||||
const endpointSecret =
|
||||
privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
|
||||
if (!endpointSecret) {
|
||||
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
|
||||
logger.warn(
|
||||
"Stripe webhook secret is not configured. Webhook events will not be priocessed."
|
||||
);
|
||||
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, ""));
|
||||
}
|
||||
|
||||
// Only verify the event if you have an endpoint secret defined.
|
||||
@@ -49,7 +50,10 @@ export async function billingWebhookHandler(
|
||||
if (!signature) {
|
||||
logger.info("No stripe signature found in headers.");
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No stripe signature found in headers"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -62,7 +66,10 @@ export async function billingWebhookHandler(
|
||||
} catch (err) {
|
||||
logger.error(`Webhook signature verification failed.`, err);
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Webhook signature verification failed"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,10 +24,10 @@ import { registry } from "@server/openApi";
|
||||
import { GetCertificateResponse } from "@server/routers/certificates/types";
|
||||
|
||||
const getCertificateSchema = z.strictObject({
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
});
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
async function query(domainId: string, domain: string) {
|
||||
const [domainRecord] = await db
|
||||
@@ -42,8 +42,8 @@ async function query(domainId: string, domain: string) {
|
||||
|
||||
let existing: any[] = [];
|
||||
if (domainRecord.type == "ns") {
|
||||
const domainLevelDown = domain.split('.').slice(1).join('.');
|
||||
|
||||
const domainLevelDown = domain.split(".").slice(1).join(".");
|
||||
|
||||
existing = await db
|
||||
.select({
|
||||
certId: certificates.certId,
|
||||
@@ -64,7 +64,7 @@ async function query(domainId: string, domain: string) {
|
||||
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
|
||||
or(
|
||||
eq(certificates.domain, domain),
|
||||
eq(certificates.domain, domainLevelDown),
|
||||
eq(certificates.domain, domainLevelDown)
|
||||
)
|
||||
)
|
||||
);
|
||||
@@ -102,8 +102,7 @@ registry.registerPath({
|
||||
tags: ["Certificate"],
|
||||
request: {
|
||||
params: z.object({
|
||||
domainId: z
|
||||
.string(),
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
})
|
||||
@@ -133,7 +132,9 @@ export async function getCertificate(
|
||||
|
||||
if (!cert) {
|
||||
logger.warn(`Certificate not found for domain: ${domainId}`);
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
|
||||
);
|
||||
}
|
||||
|
||||
return response<GetCertificateResponse>(res, {
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
*/
|
||||
|
||||
export * from "./getCertificate";
|
||||
export * from "./restartCertificate";
|
||||
export * from "./restartCertificate";
|
||||
|
||||
@@ -25,9 +25,9 @@ import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const restartCertificateParamsSchema = z.strictObject({
|
||||
certId: z.string().transform(stoi).pipe(z.int().positive()),
|
||||
orgId: z.string()
|
||||
});
|
||||
certId: z.string().transform(stoi).pipe(z.int().positive()),
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
@@ -36,10 +36,7 @@ registry.registerPath({
|
||||
tags: ["Certificate"],
|
||||
request: {
|
||||
params: z.object({
|
||||
certId: z
|
||||
.string()
|
||||
.transform(stoi)
|
||||
.pipe(z.int().positive()),
|
||||
certId: z.string().transform(stoi).pipe(z.int().positive()),
|
||||
orgId: z.string()
|
||||
})
|
||||
},
|
||||
@@ -94,7 +91,7 @@ export async function restartCertificate(
|
||||
.set({
|
||||
status: "pending",
|
||||
errorMessage: null,
|
||||
lastRenewalAttempt: Math.floor(Date.now() / 1000)
|
||||
lastRenewalAttempt: Math.floor(Date.now() / 1000)
|
||||
})
|
||||
.where(eq(certificates.certId, certId));
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types";
|
||||
const paramsSchema = z.strictObject({});
|
||||
|
||||
const querySchema = z.strictObject({
|
||||
subdomain: z.string()
|
||||
});
|
||||
subdomain: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
*/
|
||||
|
||||
export * from "./checkDomainNamespaceAvailability";
|
||||
export * from "./listDomainNamespaces";
|
||||
export * from "./listDomainNamespaces";
|
||||
|
||||
@@ -26,19 +26,19 @@ import { OpenAPITags, registry } from "@server/openApi";
|
||||
const paramsSchema = z.strictObject({});
|
||||
|
||||
const querySchema = z.strictObject({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
|
||||
async function query(limit: number, offset: number) {
|
||||
const res = await db
|
||||
|
||||
@@ -30,7 +30,7 @@ import {
|
||||
verifyUserHasAction,
|
||||
verifyUserIsServerAdmin,
|
||||
verifySiteAccess,
|
||||
verifyClientAccess,
|
||||
verifyClientAccess
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import {
|
||||
@@ -436,6 +436,8 @@ authenticated.get(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:clientId/regenerate-client-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyClientAccess,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateClientSecret
|
||||
@@ -443,15 +445,18 @@ authenticated.post(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:siteId/regenerate-site-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifySiteAccess,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateSiteSecret
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/re-key/:orgId/reGenerate-remote-exit-node-secret",
|
||||
"/re-key/:orgId/regenerate-remote-exit-node-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.updateRemoteExitNode),
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateExitNodeSecret
|
||||
);
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
@@ -79,86 +79,72 @@ import semver from "semver";
|
||||
|
||||
// Zod schemas for request validation
|
||||
const getResourceByDomainParamsSchema = z.strictObject({
|
||||
domain: z.string().min(1, "Domain is required")
|
||||
});
|
||||
domain: z.string().min(1, "Domain is required")
|
||||
});
|
||||
|
||||
const getUserSessionParamsSchema = z.strictObject({
|
||||
userSessionId: z.string().min(1, "User session ID is required")
|
||||
});
|
||||
userSessionId: z.string().min(1, "User session ID is required")
|
||||
});
|
||||
|
||||
const getUserOrgRoleParamsSchema = z.strictObject({
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
orgId: z.string().min(1, "Organization ID is required")
|
||||
});
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
orgId: z.string().min(1, "Organization ID is required")
|
||||
});
|
||||
|
||||
const getRoleResourceAccessParamsSchema = z.strictObject({
|
||||
roleId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int().positive("Role ID must be a positive integer")
|
||||
),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
roleId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Role ID must be a positive integer")),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const getUserResourceAccessParamsSchema = z.strictObject({
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const getResourceRulesParamsSchema = z.strictObject({
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const validateResourceSessionTokenParamsSchema = z.strictObject({
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const validateResourceSessionTokenBodySchema = z.strictObject({
|
||||
token: z.string().min(1, "Token is required")
|
||||
});
|
||||
token: z.string().min(1, "Token is required")
|
||||
});
|
||||
|
||||
const validateResourceAccessTokenBodySchema = z.strictObject({
|
||||
accessTokenId: z.string().optional(),
|
||||
resourceId: z.number().optional(),
|
||||
accessToken: z.string()
|
||||
});
|
||||
accessTokenId: z.string().optional(),
|
||||
resourceId: z.number().optional(),
|
||||
accessToken: z.string()
|
||||
});
|
||||
|
||||
// Certificates by domains query validation
|
||||
const getCertificatesByDomainsQuerySchema = z.strictObject({
|
||||
// Accept domains as string or array (domains or domains[])
|
||||
domains: z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional(),
|
||||
// Handle array format from query parameters (domains[])
|
||||
"domains[]": z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional()
|
||||
});
|
||||
// Accept domains as string or array (domains or domains[])
|
||||
domains: z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional(),
|
||||
// Handle array format from query parameters (domains[])
|
||||
"domains[]": z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional()
|
||||
});
|
||||
|
||||
// Type exports for request schemas
|
||||
export type GetResourceByDomainParams = z.infer<
|
||||
@@ -566,8 +552,8 @@ hybridRouter.get(
|
||||
);
|
||||
|
||||
const getOrgLoginPageParamsSchema = z.strictObject({
|
||||
orgId: z.string().min(1)
|
||||
});
|
||||
orgId: z.string().min(1)
|
||||
});
|
||||
|
||||
hybridRouter.get(
|
||||
"/org/:orgId/login-page",
|
||||
@@ -1408,8 +1394,16 @@ hybridRouter.post(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
|
||||
parsedParams.data;
|
||||
const {
|
||||
olmId,
|
||||
newtId,
|
||||
ip,
|
||||
port,
|
||||
timestamp,
|
||||
token,
|
||||
publicKey,
|
||||
reachableAt
|
||||
} = parsedParams.data;
|
||||
|
||||
const destinations = await updateAndGenerateEndpointDestinations(
|
||||
olmId,
|
||||
|
||||
@@ -18,7 +18,7 @@ import * as logs from "#private/routers/auditLogs";
|
||||
import {
|
||||
verifyApiKeyHasAction,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyOrgAccess
|
||||
} from "@server/middlewares";
|
||||
import {
|
||||
verifyValidSubscription,
|
||||
@@ -26,7 +26,10 @@ import {
|
||||
} from "#private/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
|
||||
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
|
||||
import {
|
||||
unauthenticated as ua,
|
||||
authenticated as a
|
||||
} from "@server/routers/integration";
|
||||
import { logActionAudit } from "#private/middlewares";
|
||||
|
||||
export const unauthenticated = ua;
|
||||
@@ -37,7 +40,7 @@ authenticated.post(
|
||||
verifyApiKeyIsRoot, // We are the only ones who can use root key so its fine
|
||||
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
|
||||
logActionAudit(ActionsEnum.sendUsageNotification),
|
||||
org.sendUsageNotification,
|
||||
org.sendUsageNotification
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
@@ -45,7 +48,7 @@ authenticated.delete(
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
|
||||
logActionAudit(ActionsEnum.deleteIdp),
|
||||
orgIdp.deleteOrgIdp,
|
||||
orgIdp.deleteOrgIdp
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
|
||||
@@ -21,8 +21,8 @@ import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
|
||||
export async function activateLicense(
|
||||
req: Request,
|
||||
|
||||
@@ -24,8 +24,8 @@ import { licenseKey } from "@server/db";
|
||||
import license from "#private/license/license";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
|
||||
export async function deleteLicenseKey(
|
||||
req: Request,
|
||||
|
||||
@@ -36,13 +36,13 @@ import { build } from "@server/build";
|
||||
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
subdomain: z.string().nullable().optional(),
|
||||
domainId: z.string()
|
||||
});
|
||||
subdomain: z.string().nullable().optional(),
|
||||
domainId: z.string()
|
||||
});
|
||||
|
||||
export type CreateLoginPageBody = z.infer<typeof bodySchema>;
|
||||
|
||||
@@ -149,12 +149,20 @@ export async function createLoginPage(
|
||||
|
||||
let returned: LoginPage | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
|
||||
const orgSites = await trx
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
|
||||
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.innerJoin(
|
||||
exitNodes,
|
||||
eq(exitNodes.exitNodeId, sites.exitNodeId)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.orgId, orgId),
|
||||
eq(exitNodes.type, "gerbil"),
|
||||
eq(exitNodes.online, true)
|
||||
)
|
||||
)
|
||||
.limit(10);
|
||||
|
||||
let exitNodesList = orgSites.map((s) => s.exitNodes);
|
||||
@@ -163,7 +171,12 @@ export async function createLoginPage(
|
||||
exitNodesList = await trx
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.where(
|
||||
and(
|
||||
eq(exitNodes.type, "gerbil"),
|
||||
eq(exitNodes.online, true)
|
||||
)
|
||||
)
|
||||
.limit(10);
|
||||
}
|
||||
|
||||
|
||||
@@ -78,15 +78,11 @@ export async function deleteLoginPage(
|
||||
// if (!leftoverLinks.length) {
|
||||
await db
|
||||
.delete(loginPage)
|
||||
.where(
|
||||
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
|
||||
);
|
||||
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId));
|
||||
|
||||
await db
|
||||
.delete(loginPageOrg)
|
||||
.where(
|
||||
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
|
||||
);
|
||||
.where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId));
|
||||
// }
|
||||
|
||||
return response<LoginPage>(res, {
|
||||
|
||||
@@ -23,8 +23,8 @@ import { fromError } from "zod-validation-error";
|
||||
import { GetLoginPageResponse } from "@server/routers/loginPage/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
async function query(orgId: string) {
|
||||
const [res] = await db
|
||||
|
||||
@@ -35,7 +35,8 @@ const paramsSchema = z
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
const bodySchema = z
|
||||
.strictObject({
|
||||
subdomain: subdomainSchema.nullable().optional(),
|
||||
domainId: z.string().optional()
|
||||
})
|
||||
@@ -86,7 +87,7 @@ export async function updateLoginPage(
|
||||
|
||||
const { loginPageId, orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas"){
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
@@ -182,7 +183,10 @@ export async function updateLoginPage(
|
||||
}
|
||||
|
||||
// update the full domain if it has changed
|
||||
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
|
||||
if (
|
||||
fullDomain &&
|
||||
fullDomain !== existingLoginPage?.fullDomain
|
||||
) {
|
||||
await db
|
||||
.update(loginPage)
|
||||
.set({ fullDomain })
|
||||
|
||||
@@ -23,9 +23,9 @@ import SupportEmail from "@server/emails/templates/SupportEmail";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
body: z.string().min(1),
|
||||
subject: z.string().min(1).max(255)
|
||||
});
|
||||
body: z.string().min(1),
|
||||
subject: z.string().min(1).max(255)
|
||||
});
|
||||
|
||||
export async function sendSupportEmail(
|
||||
req: Request,
|
||||
@@ -66,6 +66,7 @@ export async function sendSupportEmail(
|
||||
{
|
||||
name: req.user?.email || "Support User",
|
||||
to: "support@pangolin.net",
|
||||
replyTo: req.user?.email || undefined,
|
||||
from: config.getNoReplyEmail(),
|
||||
subject: `Support Request: ${subject}`
|
||||
}
|
||||
|
||||
@@ -11,4 +11,4 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./sendUsageNotifications";
|
||||
export * from "./sendUsageNotifications";
|
||||
|
||||
@@ -35,10 +35,12 @@ const sendUsageNotificationBodySchema = z.object({
|
||||
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
|
||||
limitName: z.string(),
|
||||
currentUsage: z.number(),
|
||||
usageLimit: z.number(),
|
||||
usageLimit: z.number()
|
||||
});
|
||||
|
||||
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
|
||||
type SendUsageNotificationRequest = z.infer<
|
||||
typeof sendUsageNotificationBodySchema
|
||||
>;
|
||||
|
||||
export type SendUsageNotificationResponse = {
|
||||
success: boolean;
|
||||
@@ -97,17 +99,13 @@ async function getOrgAdmins(orgId: string) {
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
or(
|
||||
eq(userOrgs.isOwner, true),
|
||||
eq(roles.isAdmin, true)
|
||||
)
|
||||
or(eq(userOrgs.isOwner, true), eq(roles.isAdmin, true))
|
||||
)
|
||||
);
|
||||
|
||||
// Filter to only include users with verified emails
|
||||
const orgAdmins = admins.filter(admin =>
|
||||
admin.email &&
|
||||
admin.email.length > 0
|
||||
const orgAdmins = admins.filter(
|
||||
(admin) => admin.email && admin.email.length > 0
|
||||
);
|
||||
|
||||
return orgAdmins;
|
||||
@@ -119,7 +117,9 @@ export async function sendUsageNotification(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
|
||||
const parsedParams = sendUsageNotificationParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
@@ -140,12 +140,8 @@ export async function sendUsageNotification(
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
const {
|
||||
notificationType,
|
||||
limitName,
|
||||
currentUsage,
|
||||
usageLimit,
|
||||
} = parsedBody.data;
|
||||
const { notificationType, limitName, currentUsage, usageLimit } =
|
||||
parsedBody.data;
|
||||
|
||||
// Verify organization exists
|
||||
const org = await db
|
||||
@@ -192,7 +188,10 @@ export async function sendUsageNotification(
|
||||
let template;
|
||||
let subject;
|
||||
|
||||
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
|
||||
if (
|
||||
notificationType === "approaching_70" ||
|
||||
notificationType === "approaching_90"
|
||||
) {
|
||||
template = NotifyUsageLimitApproaching({
|
||||
email: admin.email,
|
||||
limitName,
|
||||
@@ -220,10 +219,15 @@ export async function sendUsageNotification(
|
||||
|
||||
emailsSent++;
|
||||
adminEmails.push(admin.email);
|
||||
|
||||
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
|
||||
|
||||
logger.info(
|
||||
`Usage notification sent to admin ${admin.email} for org ${orgId}`
|
||||
);
|
||||
} catch (emailError) {
|
||||
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
|
||||
logger.error(
|
||||
`Failed to send usage notification to ${admin.email}:`,
|
||||
emailError
|
||||
);
|
||||
// Continue with other admins even if one fails
|
||||
}
|
||||
}
|
||||
@@ -239,11 +243,13 @@ export async function sendUsageNotification(
|
||||
message: `Usage notifications sent to ${emailsSent} administrators`,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Error sending usage notifications:", error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to send usage notifications"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,19 +32,19 @@ import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
|
||||
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
name: z.string().nonempty(),
|
||||
clientId: z.string().nonempty(),
|
||||
clientSecret: z.string().nonempty(),
|
||||
authUrl: z.url(),
|
||||
tokenUrl: z.url(),
|
||||
identifierPath: z.string().nonempty(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().nonempty(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
name: z.string().nonempty(),
|
||||
clientId: z.string().nonempty(),
|
||||
clientSecret: z.string().nonempty(),
|
||||
authUrl: z.url(),
|
||||
tokenUrl: z.url(),
|
||||
identifierPath: z.string().nonempty(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().nonempty(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "put",
|
||||
@@ -158,7 +158,10 @@ export async function createOrgOidcIdp(
|
||||
});
|
||||
});
|
||||
|
||||
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
|
||||
const redirectUrl = await generateOidcRedirectUrl(
|
||||
idpId as number,
|
||||
orgId
|
||||
);
|
||||
|
||||
return response<CreateOrgIdpResponse>(res, {
|
||||
data: {
|
||||
|
||||
@@ -66,12 +66,7 @@ export async function deleteOrgIdp(
|
||||
.where(eq(idp.idpId, idpId));
|
||||
|
||||
if (!existingIdp) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"IdP not found"
|
||||
)
|
||||
);
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
|
||||
}
|
||||
|
||||
// Delete the IDP and its related records in a transaction
|
||||
@@ -82,14 +77,10 @@ export async function deleteOrgIdp(
|
||||
.where(eq(idpOidcConfig.idpId, idpId));
|
||||
|
||||
// Delete IDP-org mappings
|
||||
await trx
|
||||
.delete(idpOrg)
|
||||
.where(eq(idpOrg.idpId, idpId));
|
||||
await trx.delete(idpOrg).where(eq(idpOrg.idpId, idpId));
|
||||
|
||||
// Delete the IDP itself
|
||||
await trx
|
||||
.delete(idp)
|
||||
.where(eq(idp.idpId, idpId));
|
||||
await trx.delete(idp).where(eq(idp.idpId, idpId));
|
||||
});
|
||||
|
||||
return response<null>(res, {
|
||||
|
||||
@@ -93,7 +93,10 @@ export async function getOrgIdp(
|
||||
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
|
||||
}
|
||||
|
||||
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
|
||||
const redirectUrl = await generateOidcRedirectUrl(
|
||||
idpRes.idp.idpId,
|
||||
orgId
|
||||
);
|
||||
|
||||
return response<GetOrgIdpResponse>(res, {
|
||||
data: {
|
||||
|
||||
@@ -15,4 +15,4 @@ export * from "./createOrgOidcIdp";
|
||||
export * from "./getOrgIdp";
|
||||
export * from "./listOrgIdps";
|
||||
export * from "./updateOrgOidcIdp";
|
||||
export * from "./deleteOrgIdp";
|
||||
export * from "./deleteOrgIdp";
|
||||
|
||||
@@ -25,23 +25,23 @@ import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { ListOrgIdpsResponse } from "@server/routers/orgIdp/types";
|
||||
|
||||
const querySchema = z.strictObject({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().nonempty()
|
||||
});
|
||||
orgId: z.string().nonempty()
|
||||
});
|
||||
|
||||
async function query(orgId: string, limit: number, offset: number) {
|
||||
const res = await db
|
||||
|
||||
@@ -36,18 +36,18 @@ const paramsSchema = z
|
||||
.strict();
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
name: z.string().optional(),
|
||||
clientId: z.string().optional(),
|
||||
clientSecret: z.string().optional(),
|
||||
authUrl: z.string().optional(),
|
||||
tokenUrl: z.string().optional(),
|
||||
identifierPath: z.string().optional(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().optional(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
name: z.string().optional(),
|
||||
clientId: z.string().optional(),
|
||||
clientSecret: z.string().optional(),
|
||||
authUrl: z.string().optional(),
|
||||
tokenUrl: z.string().optional(),
|
||||
identifierPath: z.string().optional(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().optional(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
|
||||
export type UpdateOrgIdpResponse = {
|
||||
idpId: number;
|
||||
|
||||
@@ -13,4 +13,4 @@
|
||||
|
||||
export * from "./reGenerateClientSecret";
|
||||
export * from "./reGenerateSiteSecret";
|
||||
export * from "./reGenerateExitNodeSecret";
|
||||
export * from "./reGenerateExitNodeSecret";
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, olms, } from "@server/db";
|
||||
import { db, Olm, olms } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -23,37 +23,19 @@ import { eq, and } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { disconnectClient, sendToClient } from "#private/routers/ws";
|
||||
|
||||
const reGenerateSecretParamsSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const reGenerateSecretBodySchema = z.strictObject({
|
||||
olmId: z.string().min(1).optional(),
|
||||
secret: z.string().min(1).optional(),
|
||||
|
||||
});
|
||||
|
||||
export type ReGenerateSecretBody = z.infer<typeof reGenerateSecretBodySchema>;
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/re-key/{clientId}/regenerate-client-secret",
|
||||
description: "Regenerate a client's OLM credentials by its client ID.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: reGenerateSecretParamsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: reGenerateSecretBodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const reGenerateSecretBodySchema = z.strictObject({
|
||||
// olmId: z.string().min(1).optional(),
|
||||
secret: z.string().min(1),
|
||||
disconnect: z.boolean().optional().default(true)
|
||||
});
|
||||
|
||||
export type ReGenerateSecretBody = z.infer<typeof reGenerateSecretBodySchema>;
|
||||
|
||||
export async function reGenerateClientSecret(
|
||||
req: Request,
|
||||
@@ -71,7 +53,7 @@ export async function reGenerateClientSecret(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, secret } = parsedBody.data;
|
||||
const { secret, disconnect } = parsedBody.data;
|
||||
|
||||
const parsedParams = reGenerateSecretParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -85,11 +67,7 @@ export async function reGenerateClientSecret(
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
let secretHash = undefined;
|
||||
if (secret) {
|
||||
secretHash = await hashPassword(secret);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
// Fetch the client to make sure it exists and the user has access to it
|
||||
const [client] = await db
@@ -107,24 +85,59 @@ export async function reGenerateClientSecret(
);
}

const [existingOlm] = await db
const existingOlms = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
.where(eq(olms.clientId, clientId));

if (existingOlm && olmId && secretHash) {
await db
.update(olms)
.set({
olmId,
secretHash
})
.where(eq(olms.clientId, clientId));
if (existingOlms.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`No OLM found for client ID ${clientId}`
)
);
}

if (existingOlms.length > 1) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Multiple OLM entries found for client ID ${clientId}`
)
);
}

await db
.update(olms)
.set({
secretHash
})
.where(eq(olms.olmId, existingOlms[0].olmId));

// Only disconnect if explicitly requested
if (disconnect) {
const payload = {
type: `olm/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingOlms[0].olmId, payload).catch((error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
});

disconnectClient(existingOlms[0].olmId).catch((error) => {
logger.error("Failed to disconnect olm after re-key:", error);
});
}

return response(res, {
data: existingOlm,
data: {
olmId: existingOlms[0].olmId
},
success: true,
error: false,
message: "Credentials regenerated successfully",

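// Aside: a hedged usage sketch (not part of this commit) of the re-key endpoint defined
// above. The body now takes a required `secret` plus an optional `disconnect` flag that
// defaults to true; the base URL, API prefix, and auth handling below are assumptions.
async function rotateOlmSecret(clientId: number, secret: string): Promise<void> {
    // Pass disconnect: false to rotate credentials without tearing down the live session.
    const res = await fetch(
        `https://pangolin.example.com/api/v1/re-key/${clientId}/regenerate-client-secret`,
        {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify({ secret, disconnect: false })
        }
    );
    if (!res.ok) {
        throw new Error(`Re-key failed with status ${res.status}`);
    }
}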
@@ -12,7 +12,14 @@
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
|
||||
import {
|
||||
db,
|
||||
exitNodes,
|
||||
exitNodeOrgs,
|
||||
ExitNode,
|
||||
ExitNodeOrg,
|
||||
RemoteExitNode
|
||||
} from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
@@ -22,35 +29,17 @@ import { fromError } from "zod-validation-error";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { UpdateRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { disconnectClient, sendToClient } from "#private/routers/ws";
|
||||
|
||||
export const paramsSchema = z.object({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
});
|
||||
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/re-key/{orgId}/regenerate-secret",
|
||||
description: "Regenerate a exit node credentials by its org ID.",
|
||||
tags: [OpenAPITags.Org],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: bodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48),
|
||||
disconnect: z.boolean().optional().default(true)
|
||||
});
|
||||
|
||||
export async function reGenerateExitNodeSecret(
|
||||
@@ -79,13 +68,7 @@ export async function reGenerateExitNodeSecret(
|
||||
);
|
||||
}
|
||||
|
||||
const { remoteExitNodeId, secret } = parsedBody.data;
|
||||
|
||||
if (req.user && !req.userOrgRoleId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
}
|
||||
const { remoteExitNodeId, secret, disconnect } = parsedBody.data;
|
||||
|
||||
const [existingRemoteExitNode] = await db
|
||||
.select()
|
||||
@@ -94,7 +77,10 @@ export async function reGenerateExitNodeSecret(
|
||||
|
||||
if (!existingRemoteExitNode) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Remote Exit Node does not exist")
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"Remote Exit Node does not exist"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -105,15 +91,39 @@ export async function reGenerateExitNodeSecret(
|
||||
.set({ secretHash })
|
||||
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
|
||||
|
||||
return response<UpdateRemoteExitNodeResponse>(res, {
|
||||
data: {
|
||||
remoteExitNodeId,
|
||||
secret,
|
||||
},
|
||||
// Only disconnect if explicitly requested
|
||||
if (disconnect) {
|
||||
const payload = {
|
||||
type: `remoteExitNode/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
sendToClient(
|
||||
existingRemoteExitNode.remoteExitNodeId,
|
||||
payload
|
||||
).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to remote exit node:",
|
||||
error
|
||||
);
|
||||
});
|
||||
|
||||
disconnectClient(existingRemoteExitNode.remoteExitNodeId).catch(
|
||||
(error) => {
|
||||
logger.error(
|
||||
"Failed to disconnect remote exit node after re-key:",
|
||||
error
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Remote Exit Node secret updated successfully",
|
||||
status: HttpCode.OK,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error("Failed to update remoteExitNode", e);
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, newts, sites } from "@server/db";
|
||||
import { db, Newt, newts, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -22,38 +22,19 @@ import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { addPeer } from "@server/routers/gerbil/peers";
|
||||
|
||||
import { addPeer, deletePeer } from "@server/routers/gerbil/peers";
|
||||
import { getAllowedIps } from "@server/routers/target/helpers";
|
||||
import { disconnectClient, sendToClient } from "#private/routers/ws";
|
||||
|
||||
const updateSiteParamsSchema = z.strictObject({
|
||||
siteId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
siteId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const updateSiteBodySchema = z.strictObject({
|
||||
type: z.enum(["newt", "wireguard"]),
|
||||
newtId: z.string().min(1).max(255).optional(),
|
||||
newtSecret: z.string().min(1).max(255).optional(),
|
||||
exitNodeId: z.int().positive().optional(),
|
||||
pubKey: z.string().optional(),
|
||||
subnet: z.string().optional(),
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/re-key/{siteId}/regenerate-site-secret",
|
||||
description: "Regenerate a site's Newt or WireGuard credentials by its site ID.",
|
||||
tags: [OpenAPITags.Site],
|
||||
request: {
|
||||
params: updateSiteParamsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: updateSiteBodySchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
responses: {},
|
||||
type: z.enum(["newt", "wireguard"]),
|
||||
secret: z.string().min(1).max(255).optional(),
|
||||
pubKey: z.string().optional(),
|
||||
disconnect: z.boolean().optional().default(true)
|
||||
});
|
||||
|
||||
export async function reGenerateSiteSecret(
|
||||
@@ -65,74 +46,149 @@ export async function reGenerateSiteSecret(
|
||||
const parsedParams = updateSiteParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, fromError(parsedParams.error).toString())
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsedBody = updateSiteBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, fromError(parsedBody.error).toString())
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { siteId } = parsedParams.data;
|
||||
const { type, exitNodeId, pubKey, subnet, newtId, newtSecret } = parsedBody.data;
|
||||
|
||||
let updatedSite = undefined;
|
||||
const { type, pubKey, secret, disconnect } = parsedBody.data;
|
||||
|
||||
let existingNewt: Newt | null = null;
|
||||
if (type === "newt") {
|
||||
if (!newtSecret) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "newtSecret is required for newt sites")
|
||||
);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(newtSecret);
|
||||
|
||||
updatedSite = await db
|
||||
.update(newts)
|
||||
.set({
|
||||
newtId,
|
||||
secretHash,
|
||||
})
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.returning();
|
||||
|
||||
logger.info(`Regenerated Newt credentials for site ${siteId}`);
|
||||
|
||||
} else if (type === "wireguard") {
|
||||
if (!pubKey) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Public key is required for wireguard sites")
|
||||
);
|
||||
}
|
||||
|
||||
if (!exitNodeId) {
|
||||
if (!secret) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Exit node ID is required for wireguard sites"
|
||||
"newtSecret is required for newt sites"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
// get the newt to verify it exists
|
||||
const existingNewts = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId));
|
||||
|
||||
if (existingNewts.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`No Newt found for site ID ${siteId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existingNewts.length > 1) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
`Multiple Newts found for site ID ${siteId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
existingNewt = existingNewts[0];
|
||||
|
||||
// update the secret on the existing newt
|
||||
await db
|
||||
.update(newts)
|
||||
.set({
|
||||
secretHash
|
||||
})
|
||||
.where(eq(newts.newtId, existingNewts[0].newtId));
|
||||
|
||||
// Only disconnect if explicitly requested
|
||||
if (disconnect) {
|
||||
const payload = {
|
||||
type: `newt/wg/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
sendToClient(existingNewts[0].newtId, payload).catch(
|
||||
(error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to newt:",
|
||||
error
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
disconnectClient(existingNewts[0].newtId).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to disconnect newt after re-key:",
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
logger.info(`Regenerated Newt credentials for site ${siteId}`);
|
||||
} else if (type === "wireguard") {
|
||||
if (!pubKey) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Public key is required for wireguard sites"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
updatedSite = await db.transaction(async (tx) => {
|
||||
await addPeer(exitNodeId, {
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Site with ID ${siteId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db
|
||||
.update(sites)
|
||||
.set({ pubKey })
|
||||
.where(eq(sites.siteId, siteId));
|
||||
|
||||
if (!site) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Site with ID ${siteId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (site.exitNodeId && site.subnet) {
|
||||
await deletePeer(site.exitNodeId, site.pubKey!); // the old pubkey
|
||||
await addPeer(site.exitNodeId, {
|
||||
publicKey: pubKey,
|
||||
allowedIps: subnet ? [subnet] : [],
|
||||
allowedIps: await getAllowedIps(site.siteId)
|
||||
});
|
||||
const result = await tx
|
||||
.update(sites)
|
||||
.set({ pubKey })
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.returning();
|
||||
}
|
||||
|
||||
return result;
|
||||
});
|
||||
|
||||
logger.info(`Regenerated WireGuard credentials for site ${siteId}`);
|
||||
logger.info(
|
||||
`Regenerated WireGuard credentials for site ${siteId}`
|
||||
);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`Transaction failed while regenerating WireGuard secret for site ${siteId}`,
|
||||
@@ -148,17 +204,21 @@ export async function reGenerateSiteSecret(
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: updatedSite,
|
||||
data: {
|
||||
newtId: existingNewt ? existingNewt.newtId : undefined
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Credentials regenerated successfully",
|
||||
status: HttpCode.OK,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Unexpected error in reGenerateSiteSecret", error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An unexpected error occurred")
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"An unexpected error occurred"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,9 +36,9 @@ export const paramsSchema = z.object({
|
||||
});
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
});
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
});
|
||||
|
||||
export type CreateRemoteExitNodeBody = z.infer<typeof bodySchema>;
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
|
||||
export async function deleteRemoteExitNode(
|
||||
req: Request,
|
||||
|
||||
@@ -24,9 +24,9 @@ import { fromError } from "zod-validation-error";
|
||||
import { GetRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
|
||||
|
||||
const getRemoteExitNodeSchema = z.strictObject({
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
|
||||
async function query(remoteExitNodeId: string) {
|
||||
const [remoteExitNode] = await db
|
||||
|
||||
@@ -55,7 +55,8 @@ export async function getRemoteExitNodeToken(
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
|
||||
const { session, remoteExitNode } =
|
||||
await validateRemoteExitNodeSessionToken(token);
|
||||
if (session) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
@@ -103,7 +104,10 @@ export async function getRemoteExitNodeToken(
|
||||
}
|
||||
|
||||
const resToken = generateSessionToken();
|
||||
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
|
||||
await createRemoteExitNodeSession(
|
||||
resToken,
|
||||
existingRemoteExitNode.remoteExitNodeId
|
||||
);
|
||||
|
||||
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);
|
||||
|
||||
|
||||
@@ -33,7 +33,9 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
|
||||
offlineCheckerInterval = setInterval(async () => {
|
||||
try {
|
||||
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
|
||||
const twoMinutesAgo = Math.floor(
|
||||
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
||||
);
|
||||
|
||||
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
||||
const newlyOfflineNodes = await db
|
||||
@@ -48,11 +50,13 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
isNull(exitNodes.lastPing)
|
||||
)
|
||||
)
|
||||
).returning();
|
||||
|
||||
)
|
||||
.returning();
|
||||
|
||||
// Update the sites to offline if they have not pinged either
|
||||
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
|
||||
const exitNodeIds = newlyOfflineNodes.map(
|
||||
(node) => node.exitNodeId
|
||||
);
|
||||
|
||||
const sitesOnNode = await db
|
||||
.select()
|
||||
@@ -63,10 +67,10 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
inArray(sites.exitNodeId, exitNodeIds)
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
|
||||
for (const site of sitesOnNode) {
|
||||
if (!site.lastBandwidthUpdate) {
|
||||
if (!site.lastBandwidthUpdate) {
|
||||
continue;
|
||||
}
|
||||
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
|
||||
@@ -77,13 +81,12 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
.where(eq(sites.siteId, site.siteId));
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Error in offline checker interval", { error });
|
||||
}
|
||||
}, OFFLINE_CHECK_INTERVAL);
|
||||
|
||||
logger.info("Started offline checker interval");
|
||||
logger.debug("Started offline checker interval");
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -100,7 +103,9 @@ export const stopRemoteExitNodeOfflineChecker = (): void => {
|
||||
/**
|
||||
* Handles ping messages from clients and responds with pong
|
||||
*/
|
||||
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
|
||||
export const handleRemoteExitNodePingMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const remoteExitNode = c as RemoteExitNode;
|
||||
|
||||
@@ -120,7 +125,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
|
||||
.update(exitNodes)
|
||||
.set({
|
||||
lastPing: Math.floor(Date.now() / 1000),
|
||||
online: true,
|
||||
online: true
|
||||
})
|
||||
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
|
||||
} catch (error) {
|
||||
@@ -131,10 +136,10 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
|
||||
message: {
|
||||
type: "pong",
|
||||
data: {
|
||||
timestamp: new Date().toISOString(),
|
||||
timestamp: new Date().toISOString()
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -29,7 +29,8 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
|
||||
return;
|
||||
}
|
||||
|
||||
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } = message.data;
|
||||
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } =
|
||||
message.data;
|
||||
|
||||
if (!remoteExitNodeVersion) {
|
||||
logger.warn("Remote exit node version not found");
|
||||
@@ -39,7 +40,10 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
|
||||
// update the version
|
||||
await db
|
||||
.update(remoteExitNodes)
|
||||
.set({ version: remoteExitNodeVersion, secondaryVersion: remoteExitNodeSecondaryVersion })
|
||||
.set({
|
||||
version: remoteExitNodeVersion,
|
||||
secondaryVersion: remoteExitNodeSecondaryVersion
|
||||
})
|
||||
.where(
|
||||
eq(
|
||||
remoteExitNodes.remoteExitNodeId,
|
||||
|
||||
@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
|
||||
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
|
||||
|
||||
const listRemoteExitNodesParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const listRemoteExitNodesSchema = z.object({
|
||||
limit: z
|
||||
|
||||
@@ -22,8 +22,8 @@ import { z } from "zod";
|
||||
import { PickRemoteExitNodeDefaultsResponse } from "@server/routers/remoteExitNode/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function pickRemoteExitNodeDefaults(
|
||||
req: Request,
|
||||
|
||||
@@ -38,7 +38,9 @@ export async function quickStartRemoteExitNode(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(req.body);
|
||||
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(
|
||||
req.body
|
||||
);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -11,4 +11,4 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./ws";
|
||||
export * from "./ws";
|
||||
|
||||
@@ -23,4 +23,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"remoteExitNode/ping": handleRemoteExitNodePingMessage
|
||||
};
|
||||
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
|
||||
@@ -37,7 +37,14 @@ import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remot
|
||||
import { rateLimitService } from "#private/lib/rateLimit";
|
||||
import { messageHandlers } from "@server/routers/ws/messageHandlers";
|
||||
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
|
||||
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
|
||||
import {
|
||||
AuthenticatedWebSocket,
|
||||
ClientType,
|
||||
WSMessage,
|
||||
TokenPayload,
|
||||
WebSocketRequest,
|
||||
RedisMessage
|
||||
} from "@server/routers/ws";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
|
||||
// Merge public and private message handlers
|
||||
@@ -55,9 +62,9 @@ const processMessage = async (
|
||||
try {
|
||||
const message: WSMessage = JSON.parse(data.toString());
|
||||
|
||||
logger.debug(
|
||||
`Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||
// );
|
||||
|
||||
if (!message.type || typeof message.type !== "string") {
|
||||
throw new Error("Invalid message format: missing or invalid type");
@@ -216,7 +223,7 @@ const initializeRedisSubscription = async (): Promise<void> => {
// Each node is responsible for restoring its own connection state to Redis
// This approach is more efficient than cross-node coordination because:
// 1. Each node knows its own connections (source of truth)
// 2. No network overhead from broadcasting state between nodes
// 2. No network overhead from broadcasting state between nodes
// 3. No race conditions from simultaneous updates
// 4. Redis becomes eventually consistent as each node restores independently
// 5. Simpler logic with better fault tolerance
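// Aside: a minimal sketch (not from this commit) of the per-node restoration idea described
// above, assuming an ioredis-style client; the key format, NODE_ID value, and the local
// connectedClients map are illustrative stand-ins, not the project's actual helpers.
import Redis from "ioredis";
import WebSocket from "ws";

const NODE_ID = "node-a"; // assumed unique ID for this server instance
const redis = new Redis();
const connectedClients = new Map<string, WebSocket[]>();

// After a Redis reconnect, each node re-announces only its own open sockets;
// no node ever tries to reconcile another node's state.
async function restoreLocalConnectionsSketch(): Promise<void> {
    for (const [clientId, sockets] of connectedClients.entries()) {
        const open = sockets.filter((s) => s.readyState === WebSocket.OPEN);
        if (open.length > 0) {
            await redis.sadd(`ws:connections:${clientId}`, NODE_ID);
        }
    }
}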
@@ -233,8 +240,10 @@ const recoverConnectionState = async (): Promise<void> => {
|
||||
// Each node simply restores its own local connections to Redis
|
||||
// This is the source of truth - no need for cross-node coordination
|
||||
await restoreLocalConnectionsToRedis();
|
||||
|
||||
logger.info("Redis connection state recovery completed - restored local state");
|
||||
|
||||
logger.info(
|
||||
"Redis connection state recovery completed - restored local state"
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Error during Redis recovery:", error);
|
||||
} finally {
|
||||
@@ -251,8 +260,10 @@ const restoreLocalConnectionsToRedis = async (): Promise<void> => {
|
||||
try {
|
||||
// Restore all current local connections to Redis
|
||||
for (const [clientId, clients] of connectedClients.entries()) {
|
||||
const validClients = clients.filter(client => client.readyState === WebSocket.OPEN);
|
||||
|
||||
const validClients = clients.filter(
|
||||
(client) => client.readyState === WebSocket.OPEN
|
||||
);
|
||||
|
||||
if (validClients.length > 0) {
|
||||
// Add this node to the client's connection list
|
||||
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
|
||||
@@ -303,7 +314,10 @@ const addClient = async (
|
||||
Date.now().toString()
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to add client to Redis tracking (connection still functional locally):", error);
|
||||
logger.error(
|
||||
"Failed to add client to Redis tracking (connection still functional locally):",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,9 +340,14 @@ const removeClient = async (
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
|
||||
await redisManager.del(
|
||||
getNodeConnectionsKey(NODE_ID, clientId)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to remove client from Redis tracking (cleanup will occur on recovery):", error);
|
||||
logger.error(
|
||||
"Failed to remove client from Redis tracking (cleanup will occur on recovery):",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,7 +364,10 @@ const removeClient = async (
|
||||
ws.connectionId
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to remove specific connection from Redis tracking:", error);
|
||||
logger.error(
|
||||
"Failed to remove specific connection from Redis tracking:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,7 +394,9 @@ const sendToClientLocal = async (
|
||||
}
|
||||
});
|
||||
|
||||
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
|
||||
logger.debug(
|
||||
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
|
||||
);
|
||||
|
||||
return true;
|
||||
};
|
||||
@@ -411,14 +435,22 @@ const sendToClient = async (
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
|
||||
await redisManager.publish(
|
||||
REDIS_CHANNEL,
|
||||
JSON.stringify(redisMessage)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to send message via Redis, message may be lost:", error);
|
||||
logger.error(
|
||||
"Failed to send message via Redis, message may be lost:",
|
||||
error
|
||||
);
|
||||
// Continue execution - local delivery already attempted
|
||||
}
|
||||
} else if (!localSent && !redisManager.isRedisEnabled()) {
|
||||
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
|
||||
logger.debug(`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`);
|
||||
logger.debug(
|
||||
`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`
|
||||
);
|
||||
}
|
||||
|
||||
return localSent;
|
||||
@@ -441,13 +473,21 @@ const broadcastToAllExcept = async (
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
|
||||
await redisManager.publish(
|
||||
REDIS_CHANNEL,
|
||||
JSON.stringify(redisMessage)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to broadcast message via Redis, remote nodes may not receive it:", error);
|
||||
logger.error(
|
||||
"Failed to broadcast message via Redis, remote nodes may not receive it:",
|
||||
error
|
||||
);
|
||||
// Continue execution - local broadcast already completed
|
||||
}
|
||||
} else {
|
||||
logger.debug("Redis unavailable - broadcast limited to local node only");
|
||||
logger.debug(
|
||||
"Redis unavailable - broadcast limited to local node only"
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -512,8 +552,10 @@ const verifyToken = async (
|
||||
return null;
|
||||
}
|
||||
|
||||
if (olm.userId) { // this is a user device and we need to check the user token
|
||||
const { session: userSession, user } = await validateSessionToken(userToken);
|
||||
if (olm.userId) {
|
||||
// this is a user device and we need to check the user token
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
return null;
|
||||
}
|
||||
@@ -668,7 +710,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
url.searchParams.get("token") ||
|
||||
request.headers["sec-websocket-protocol"] ||
|
||||
"";
|
||||
const userToken = url.searchParams.get('userToken') || '';
|
||||
const userToken = url.searchParams.get("userToken") || "";
|
||||
let clientType = url.searchParams.get(
|
||||
"clientType"
|
||||
) as ClientType;
|
||||
@@ -690,7 +732,11 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
return;
|
||||
}
|
||||
|
||||
const tokenPayload = await verifyToken(token, clientType, userToken);
|
||||
const tokenPayload = await verifyToken(
|
||||
token,
|
||||
clientType,
|
||||
userToken
|
||||
);
|
||||
if (!tokenPayload) {
|
||||
logger.debug(
|
||||
"Unauthorized connection attempt: invalid token..."
|
||||
@@ -724,50 +770,68 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
// Add periodic connection state sync to handle Redis disconnections/reconnections
|
||||
const startPeriodicStateSync = (): void => {
|
||||
// Lightweight sync every 5 minutes - just restore our own state
|
||||
setInterval(async () => {
|
||||
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
|
||||
try {
|
||||
await restoreLocalConnectionsToRedis();
|
||||
logger.debug("Periodic connection state sync completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during periodic connection state sync:", error);
|
||||
setInterval(
|
||||
async () => {
|
||||
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
|
||||
try {
|
||||
await restoreLocalConnectionsToRedis();
|
||||
logger.debug("Periodic connection state sync completed");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Error during periodic connection state sync:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 5 * 60 * 1000); // 5 minutes
|
||||
},
|
||||
5 * 60 * 1000
|
||||
); // 5 minutes
|
||||
|
||||
// Cleanup stale connections every 15 minutes
|
||||
setInterval(async () => {
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await cleanupStaleConnections();
|
||||
logger.debug("Periodic connection cleanup completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during periodic connection cleanup:", error);
|
||||
setInterval(
|
||||
async () => {
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await cleanupStaleConnections();
|
||||
logger.debug("Periodic connection cleanup completed");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Error during periodic connection cleanup:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 15 * 60 * 1000); // 15 minutes
|
||||
},
|
||||
15 * 60 * 1000
|
||||
); // 15 minutes
|
||||
};
|
||||
|
||||
const cleanupStaleConnections = async (): Promise<void> => {
|
||||
if (!redisManager.isRedisEnabled()) return;
|
||||
|
||||
try {
|
||||
const nodeKeys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
|
||||
|
||||
const nodeKeys =
|
||||
(await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`)) ||
|
||||
[];
|
||||
|
||||
for (const nodeKey of nodeKeys) {
|
||||
const connections = await redisManager.hgetall(nodeKey);
|
||||
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, '');
|
||||
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, "");
|
||||
const localClients = connectedClients.get(clientId) || [];
|
||||
const localConnectionIds = localClients
|
||||
.filter(client => client.readyState === WebSocket.OPEN)
|
||||
.map(client => client.connectionId)
|
||||
.filter((client) => client.readyState === WebSocket.OPEN)
|
||||
.map((client) => client.connectionId)
|
||||
.filter(Boolean);
|
||||
|
||||
// Remove Redis entries for connections that no longer exist locally
|
||||
for (const [connectionId, timestamp] of Object.entries(connections)) {
|
||||
for (const [connectionId, timestamp] of Object.entries(
|
||||
connections
|
||||
)) {
|
||||
if (!localConnectionIds.includes(connectionId)) {
|
||||
await redisManager.hdel(nodeKey, connectionId);
|
||||
logger.debug(`Cleaned up stale connection: ${connectionId} for client: ${clientId}`);
|
||||
logger.debug(
|
||||
`Cleaned up stale connection: ${connectionId} for client: ${clientId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -776,7 +840,9 @@ const cleanupStaleConnections = async (): Promise<void> => {
|
||||
if (Object.keys(remainingConnections).length === 0) {
|
||||
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.del(nodeKey);
|
||||
logger.debug(`Cleaned up empty connection tracking for client: ${clientId}`);
|
||||
logger.debug(
|
||||
`Cleaned up empty connection tracking for client: ${clientId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -789,38 +855,38 @@ if (redisManager.isRedisEnabled()) {
|
||||
initializeRedisSubscription().catch((error) => {
|
||||
logger.error("Failed to initialize Redis subscription:", error);
|
||||
});
|
||||
|
||||
|
||||
// Register recovery callback with Redis manager
|
||||
// When Redis reconnects, each node simply restores its own local state
|
||||
redisManager.onReconnection(async () => {
|
||||
logger.info("Redis reconnected, starting WebSocket state recovery...");
|
||||
await recoverConnectionState();
|
||||
});
|
||||
|
||||
|
||||
// Start periodic state synchronization
|
||||
startPeriodicStateSync();
|
||||
|
||||
|
||||
logger.info(
|
||||
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
|
||||
);
|
||||
} else {
|
||||
logger.debug(
|
||||
"WebSocket handler initialized in local mode"
|
||||
);
|
||||
logger.debug("WebSocket handler initialized in local mode");
|
||||
}
|
||||
|
||||
// Disconnect a specific client and force them to reconnect
|
||||
const disconnectClient = async (clientId: string): Promise<boolean> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
|
||||
|
||||
if (!clients || clients.length === 0) {
|
||||
logger.debug(`No connections found for client ID: ${clientId}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
|
||||
|
||||
logger.info(
|
||||
`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`
|
||||
);
|
||||
|
||||
// Close all connections for this client
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {