Format all files

This commit is contained in:
Owen
2025-12-09 10:56:14 -05:00
parent fa839a811f
commit f9b03943c3
535 changed files with 7670 additions and 5626 deletions

View File

@@ -11,11 +11,14 @@
* This file is not licensed under the AGPLv3.
*/
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { RemoteExitNode, remoteExitNodes, remoteExitNodeSessions, RemoteExitNodeSession } from "@server/db";
import {
RemoteExitNode,
remoteExitNodes,
remoteExitNodeSessions,
RemoteExitNodeSession
} from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
@@ -23,30 +26,39 @@ export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createRemoteExitNodeSession(
token: string,
remoteExitNodeId: string,
remoteExitNodeId: string
): Promise<RemoteExitNodeSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const session: RemoteExitNodeSession = {
sessionId: sessionId,
remoteExitNodeId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
expiresAt: new Date(Date.now() + EXPIRES).getTime()
};
await db.insert(remoteExitNodeSessions).values(session);
return session;
}
export async function validateRemoteExitNodeSessionToken(
token: string,
token: string
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ remoteExitNode: remoteExitNodes, session: remoteExitNodeSessions })
.select({
remoteExitNode: remoteExitNodes,
session: remoteExitNodeSessions
})
.from(remoteExitNodeSessions)
.innerJoin(remoteExitNodes, eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodes.remoteExitNodeId))
.innerJoin(
remoteExitNodes,
eq(
remoteExitNodeSessions.remoteExitNodeId,
remoteExitNodes.remoteExitNodeId
)
)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, remoteExitNode: null };
@@ -58,26 +70,32 @@ export async function validateRemoteExitNodeSessionToken(
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
return { session: null, remoteExitNode: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
if (Date.now() >= session.expiresAt - EXPIRES / 2) {
session.expiresAt = new Date(Date.now() + EXPIRES).getTime();
await db
.update(remoteExitNodeSessions)
.set({
expiresAt: session.expiresAt,
expiresAt: session.expiresAt
})
.where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
}
return { session, remoteExitNode };
}
export async function invalidateRemoteExitNodeSession(sessionId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.sessionId, sessionId));
export async function invalidateRemoteExitNodeSession(
sessionId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.sessionId, sessionId));
}
export async function invalidateAllRemoteExitNodeSessions(remoteExitNodeId: string): Promise<void> {
await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
export async function invalidateAllRemoteExitNodeSessions(
remoteExitNodeId: string
): Promise<void> {
await db
.delete(remoteExitNodeSessions)
.where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
}
export type SessionValidationResult =

View File

@@ -25,4 +25,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

View File

@@ -12,4 +12,4 @@
*/
export * from "./getOrgTierData";
export * from "./createCustomer";
export * from "./createCustomer";

View File

@@ -55,7 +55,6 @@ export async function getValidCertificatesForDomains(
domains: Set<string>,
useCache: boolean = true
): Promise<Array<CertificateResult>> {
loadEncryptData(); // Ensure encryption key is loaded
const finalResults: CertificateResult[] = [];

View File

@@ -12,14 +12,7 @@
*/
import { build } from "@server/build";
import {
db,
Org,
orgs,
ResourceSession,
sessions,
users
} from "@server/db";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license";

View File

@@ -66,7 +66,9 @@ export async function sendToExitNode(
// logger.debug(`Configured local exit node name: ${config.getRawConfig().gerbil.exit_node_name}`);
if (exitNode.name == config.getRawConfig().gerbil.exit_node_name) {
hostname = privateConfig.getRawPrivateConfig().gerbil.local_exit_node_reachable_at;
hostname =
privateConfig.getRawPrivateConfig().gerbil
.local_exit_node_reachable_at;
}
if (!hostname) {

View File

@@ -44,43 +44,53 @@ async function checkExitNodeOnlineStatus(
const delayBetweenAttempts = 100; // 100ms delay between starting each attempt
// Create promises for all attempts with staggered delays
const attemptPromises = Array.from({ length: maxAttempts }, async (_, index) => {
const attemptNumber = index + 1;
// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) => setTimeout(resolve, delayBetweenAttempts * index));
}
const attemptPromises = Array.from(
{ length: maxAttempts },
async (_, index) => {
const attemptNumber = index + 1;
try {
const response = await axios.get(`http://${endpoint}/ping`, {
timeout: timeoutMs,
validateStatus: (status) => status === 200
});
if (response.status === 200) {
logger.debug(
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
// Add delay before each attempt (except the first)
if (index > 0) {
await new Promise((resolve) =>
setTimeout(resolve, delayBetweenAttempts * index)
);
return { success: true, attemptNumber };
}
return { success: false, attemptNumber, error: 'Non-200 status' };
} catch (error) {
const errorMessage = error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };
try {
const response = await axios.get(`http://${endpoint}/ping`, {
timeout: timeoutMs,
validateStatus: (status) => status === 200
});
if (response.status === 200) {
logger.debug(
`Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
);
return { success: true, attemptNumber };
}
return {
success: false,
attemptNumber,
error: "Non-200 status"
};
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
logger.debug(
`Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
);
return { success: false, attemptNumber, error: errorMessage };
}
}
});
);
try {
// Wait for the first successful response or all to fail
const results = await Promise.allSettled(attemptPromises);
// Check if any attempt succeeded
for (const result of results) {
if (result.status === 'fulfilled' && result.value.success) {
if (result.status === "fulfilled" && result.value.success) {
return true;
}
}
@@ -137,7 +147,11 @@ export async function verifyExitNodeOrgAccess(
return { hasAccess: false, exitNode };
}
export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
export async function listExitNodes(
orgId: string,
filterOnline = false,
noCloud = false
) {
const allExitNodes = await db
.select({
exitNodeId: exitNodes.exitNodeId,
@@ -166,7 +180,10 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
eq(exitNodes.type, "gerbil"),
or(
// only choose nodes that are in the same region
eq(exitNodes.region, config.getRawPrivateConfig().app.region),
eq(
exitNodes.region,
config.getRawPrivateConfig().app.region
),
isNull(exitNodes.region) // or for enterprise where region is not set
)
),
@@ -191,7 +208,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
// let online: boolean;
// if (filterOnline && node.type == "remoteExitNode") {
// try {
// const isActuallyOnline = await checkExitNodeOnlineStatus(
// const isActuallyOnline = await checkExitNodeOnlineStatus(
// node.endpoint
// );
@@ -225,7 +242,8 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
node.type === "remoteExitNode" && (!filterOnline || node.online)
);
const gerbilExitNodes = allExitNodes.filter(
(node) => node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
(node) =>
node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
);
// THIS PROVIDES THE FALL
@@ -334,7 +352,11 @@ export function selectBestExitNode(
return fallbackNode;
}
export async function checkExitNodeOrg(exitNodeId: number, orgId: string, trx: Transaction | typeof db = db) {
export async function checkExitNodeOrg(
exitNodeId: number,
orgId: string,
trx: Transaction | typeof db = db
) {
const [exitNodeOrg] = await trx
.select()
.from(exitNodeOrgs)

View File

@@ -12,4 +12,4 @@
*/
export * from "./exitNodeComms";
export * from "./exitNodes";
export * from "./exitNodes";

View File

@@ -177,7 +177,9 @@ export class LockManager {
const exists = value !== null;
const ownedByMe =
exists &&
value!.startsWith(`${config.getRawConfig().gerbil.exit_node_name}:`);
value!.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
);
const owner = exists ? value!.split(":")[0] : undefined;
return {

View File

@@ -14,15 +14,15 @@
// Simple test file for the rate limit service with Redis
// Run with: npx ts-node rateLimitService.test.ts
import { RateLimitService } from './rateLimit';
import { RateLimitService } from "./rateLimit";
function generateClientId() {
return 'client-' + Math.random().toString(36).substring(2, 15);
return "client-" + Math.random().toString(36).substring(2, 15);
}
async function runTests() {
console.log('Starting Rate Limit Service Tests...\n');
console.log("Starting Rate Limit Service Tests...\n");
const rateLimitService = new RateLimitService();
let testsPassed = 0;
let testsTotal = 0;
@@ -47,36 +47,54 @@ async function runTests() {
}
// Test 1: Basic rate limiting
await test('Should allow requests under the limit', async () => {
await test("Should allow requests under the limit", async () => {
const clientId = generateClientId();
const maxRequests = 5;
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
assert(result.totalHits === i + 1, `Expected ${i + 1} hits, got ${result.totalHits}`);
assert(
result.totalHits === i + 1,
`Expected ${i + 1} hits, got ${result.totalHits}`
);
}
});
// Test 2: Rate limit blocking
await test('Should block requests over the limit', async () => {
await test("Should block requests over the limit", async () => {
const clientId = generateClientId();
const maxRequests = 30;
// Use up all allowed requests
for (let i = 0; i < maxRequests - 1; i++) {
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, `Request ${i + 1} should be allowed`);
}
// Next request should be blocked
const blockedResult = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(blockedResult.isLimited, 'Request should be blocked');
assert(blockedResult.reason === 'global', 'Should be blocked for global reason');
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(blockedResult.isLimited, "Request should be blocked");
assert(
blockedResult.reason === "global",
"Should be blocked for global reason"
);
});
// Test 3: Message type limits
await test('Should handle message type limits', async () => {
await test("Should handle message type limits", async () => {
const clientId = generateClientId();
const globalMax = 10;
const messageTypeMax = 2;
@@ -84,54 +102,64 @@ async function runTests() {
// Send messages of type 'ping' up to the limit
for (let i = 0; i < messageTypeMax - 1; i++) {
const result = await rateLimitService.checkRateLimit(
clientId,
'ping',
globalMax,
clientId,
"ping",
globalMax,
messageTypeMax
);
assert(!result.isLimited, `Ping message ${i + 1} should be allowed`);
assert(
!result.isLimited,
`Ping message ${i + 1} should be allowed`
);
}
// Next 'ping' should be blocked
const blockedResult = await rateLimitService.checkRateLimit(
clientId,
'ping',
globalMax,
clientId,
"ping",
globalMax,
messageTypeMax
);
assert(blockedResult.isLimited, 'Ping message should be blocked');
assert(blockedResult.reason === 'message_type:ping', 'Should be blocked for message type');
assert(blockedResult.isLimited, "Ping message should be blocked");
assert(
blockedResult.reason === "message_type:ping",
"Should be blocked for message type"
);
// Other message types should still work
const otherResult = await rateLimitService.checkRateLimit(
clientId,
'pong',
globalMax,
clientId,
"pong",
globalMax,
messageTypeMax
);
assert(!otherResult.isLimited, 'Pong message should be allowed');
assert(!otherResult.isLimited, "Pong message should be allowed");
});
// Test 4: Reset functionality
await test('Should reset client correctly', async () => {
await test("Should reset client correctly", async () => {
const clientId = generateClientId();
const maxRequests = 3;
// Use up some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, 'test', maxRequests);
await rateLimitService.checkRateLimit(clientId, "test", maxRequests);
// Reset the client
await rateLimitService.resetKey(clientId);
// Should be able to make fresh requests
const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(!result.isLimited, 'Request after reset should be allowed');
assert(result.totalHits === 1, 'Should have 1 hit after reset');
const result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(!result.isLimited, "Request after reset should be allowed");
assert(result.totalHits === 1, "Should have 1 hit after reset");
});
// Test 5: Different clients are independent
await test('Should handle different clients independently', async () => {
await test("Should handle different clients independently", async () => {
const client1 = generateClientId();
const client2 = generateClientId();
const maxRequests = 2;
@@ -139,43 +167,62 @@ async function runTests() {
// Client 1 uses up their limit
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
const client1Blocked = await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
assert(client1Blocked.isLimited, 'Client 1 should be blocked');
const client1Blocked = await rateLimitService.checkRateLimit(
client1,
undefined,
maxRequests
);
assert(client1Blocked.isLimited, "Client 1 should be blocked");
// Client 2 should still be able to make requests
const client2Result = await rateLimitService.checkRateLimit(client2, undefined, maxRequests);
assert(!client2Result.isLimited, 'Client 2 should not be blocked');
assert(client2Result.totalHits === 1, 'Client 2 should have 1 hit');
const client2Result = await rateLimitService.checkRateLimit(
client2,
undefined,
maxRequests
);
assert(!client2Result.isLimited, "Client 2 should not be blocked");
assert(client2Result.totalHits === 1, "Client 2 should have 1 hit");
});
// Test 6: Decrement functionality
await test('Should decrement correctly', async () => {
await test("Should decrement correctly", async () => {
const clientId = generateClientId();
const maxRequests = 5;
// Make some requests
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
let result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits before decrement');
let result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(result.totalHits === 3, "Should have 3 hits before decrement");
// Decrement
await rateLimitService.decrementRateLimit(clientId);
// Next request should reflect the decrement
result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
assert(result.totalHits === 3, 'Should have 3 hits after decrement + increment');
result = await rateLimitService.checkRateLimit(
clientId,
undefined,
maxRequests
);
assert(
result.totalHits === 3,
"Should have 3 hits after decrement + increment"
);
});
// Wait a moment for any pending Redis operations
console.log('\nWaiting for Redis sync...');
await new Promise(resolve => setTimeout(resolve, 1000));
console.log("\nWaiting for Redis sync...");
await new Promise((resolve) => setTimeout(resolve, 1000));
// Force sync to test Redis integration
await test('Should sync to Redis', async () => {
await test("Should sync to Redis", async () => {
await rateLimitService.forceSyncAllPendingData();
// If this doesn't throw, Redis sync is working
assert(true, 'Redis sync completed');
assert(true, "Redis sync completed");
});
// Cleanup
@@ -185,18 +232,18 @@ async function runTests() {
console.log(`\n--- Test Results ---`);
console.log(`✅ Passed: ${testsPassed}/${testsTotal}`);
console.log(`❌ Failed: ${testsTotal - testsPassed}/${testsTotal}`);
if (testsPassed === testsTotal) {
console.log('\n🎉 All tests passed!');
console.log("\n🎉 All tests passed!");
process.exit(0);
} else {
console.log('\n💥 Some tests failed!');
console.log("\n💥 Some tests failed!");
process.exit(1);
}
}
// Run the tests
runTests().catch(error => {
console.error('Test runner error:', error);
runTests().catch((error) => {
console.error("Test runner error:", error);
process.exit(1);
});
});

View File

@@ -40,7 +40,8 @@ interface RateLimitResult {
export class RateLimitService {
private localRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> = new Map();
private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> =
new Map();
private cleanupInterval: NodeJS.Timeout | null = null;
private forceSyncInterval: NodeJS.Timeout | null = null;
@@ -68,12 +69,18 @@ export class RateLimitService {
return `ratelimit:${clientId}`;
}
private getMessageTypeRateLimitKey(clientId: string, messageType: string): string {
private getMessageTypeRateLimitKey(
clientId: string,
messageType: string
): string {
return `ratelimit:${clientId}:${messageType}`;
}
// Helper function to clean up old timestamp fields from a Redis hash
private async cleanupOldTimestamps(key: string, windowStart: number): Promise<void> {
private async cleanupOldTimestamps(
key: string,
windowStart: number
): Promise<void> {
if (!redisManager.isRedisEnabled()) return;
try {
@@ -101,10 +108,15 @@ export class RateLimitService {
const batch = fieldsToDelete.slice(i, i + batchSize);
await client.hdel(key, ...batch);
}
logger.debug(`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`);
logger.debug(
`Cleaned up ${fieldsToDelete.length} old timestamp fields from ${key}`
);
}
} catch (error) {
logger.error(`Failed to cleanup old timestamps for key ${key}:`, error);
logger.error(
`Failed to cleanup old timestamps for key ${key}:`,
error
);
// Don't throw - cleanup failures shouldn't block rate limiting
}
}
@@ -114,7 +126,8 @@ export class RateLimitService {
clientId: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;
try {
const currentTime = Math.floor(Date.now() / 1000);
@@ -132,7 +145,11 @@ export class RateLimitService {
const newValue = (
parseInt(currentValue || "0") + tracker.pendingCount
).toString();
await redisManager.hset(globalKey, currentTime.toString(), newValue);
await redisManager.hset(
globalKey,
currentTime.toString(),
newValue
);
// Set TTL using the client directly - this prevents the key from persisting forever
if (redisManager.getClient()) {
@@ -145,7 +162,9 @@ export class RateLimitService {
tracker.lastSyncedCount = tracker.count;
tracker.pendingCount = 0;
logger.debug(`Synced global rate limit to Redis for client ${clientId}`);
logger.debug(
`Synced global rate limit to Redis for client ${clientId}`
);
} catch (error) {
logger.error("Failed to sync global rate limit to Redis:", error);
}
@@ -156,12 +175,16 @@ export class RateLimitService {
messageType: string,
tracker: RateLimitTracker
): Promise<void> {
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;
if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0)
return;
try {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);
// Clean up old timestamp fields before writing
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
@@ -195,12 +218,17 @@ export class RateLimitService {
`Synced message type rate limit to Redis for client ${clientId}, type ${messageType}`
);
} catch (error) {
logger.error("Failed to sync message type rate limit to Redis:", error);
logger.error(
"Failed to sync message type rate limit to Redis:",
error
);
}
}
// Initialize local tracker from Redis data
private async initializeLocalTracker(clientId: string): Promise<RateLimitTracker> {
private async initializeLocalTracker(
clientId: string
): Promise<RateLimitTracker> {
const currentTime = Math.floor(Date.now() / 1000);
const windowStart = currentTime - RATE_LIMIT_WINDOW;
@@ -215,14 +243,16 @@ export class RateLimitService {
try {
const globalKey = this.getRateLimitKey(clientId);
// Clean up old timestamp fields before reading
await this.cleanupOldTimestamps(globalKey, windowStart);
const globalRateLimitData = await redisManager.hgetall(globalKey);
let count = 0;
for (const [timestamp, countStr] of Object.entries(globalRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
globalRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
@@ -236,7 +266,10 @@ export class RateLimitService {
lastSyncedCount: count
};
} catch (error) {
logger.error("Failed to initialize global tracker from Redis:", error);
logger.error(
"Failed to initialize global tracker from Redis:",
error
);
return {
count: 0,
windowStart: currentTime,
@@ -263,15 +296,21 @@ export class RateLimitService {
}
try {
const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
const messageTypeKey = this.getMessageTypeRateLimitKey(
clientId,
messageType
);
// Clean up old timestamp fields before reading
await this.cleanupOldTimestamps(messageTypeKey, windowStart);
const messageTypeRateLimitData = await redisManager.hgetall(messageTypeKey);
const messageTypeRateLimitData =
await redisManager.hgetall(messageTypeKey);
let count = 0;
for (const [timestamp, countStr] of Object.entries(messageTypeRateLimitData)) {
for (const [timestamp, countStr] of Object.entries(
messageTypeRateLimitData
)) {
const time = parseInt(timestamp);
if (time >= windowStart) {
count += parseInt(countStr);
@@ -285,7 +324,10 @@ export class RateLimitService {
lastSyncedCount: count
};
} catch (error) {
logger.error("Failed to initialize message type tracker from Redis:", error);
logger.error(
"Failed to initialize message type tracker from Redis:",
error
);
return {
count: 0,
windowStart: currentTime,
@@ -327,7 +369,10 @@ export class RateLimitService {
isLimited: true,
reason: "global",
totalHits: globalTracker.count,
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(globalTracker.windowStart + Math.floor(windowMs / 1000)) *
1000
)
};
}
@@ -339,19 +384,32 @@ export class RateLimitService {
// Check message type specific rate limit if messageType is provided
if (messageType) {
const messageTypeKey = `${clientId}:${messageType}`;
let messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
let messageTypeTracker =
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
if (!messageTypeTracker || messageTypeTracker.windowStart < windowStart) {
if (
!messageTypeTracker ||
messageTypeTracker.windowStart < windowStart
) {
// New window or first request for this message type - initialize from Redis if available
messageTypeTracker = await this.initializeMessageTypeTracker(clientId, messageType);
messageTypeTracker = await this.initializeMessageTypeTracker(
clientId,
messageType
);
messageTypeTracker.windowStart = currentTime;
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
this.localMessageTypeRateLimitTracker.set(
messageTypeKey,
messageTypeTracker
);
}
// Increment message type counters
messageTypeTracker.count++;
messageTypeTracker.pendingCount++;
this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
this.localMessageTypeRateLimitTracker.set(
messageTypeKey,
messageTypeTracker
);
// Check if message type limit would be exceeded
if (messageTypeTracker.count >= messageTypeLimit) {
@@ -359,25 +417,38 @@ export class RateLimitService {
isLimited: true,
reason: `message_type:${messageType}`,
totalHits: messageTypeTracker.count,
resetTime: new Date((messageTypeTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(messageTypeTracker.windowStart +
Math.floor(windowMs / 1000)) *
1000
)
};
}
// Sync to Redis if threshold reached
if (messageTypeTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
this.syncMessageTypeRateLimitToRedis(clientId, messageType, messageTypeTracker);
this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
messageTypeTracker
);
}
}
return {
isLimited: false,
totalHits: globalTracker.count,
resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
resetTime: new Date(
(globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000
)
};
}
// Decrement function for skipSuccessfulRequests/skipFailedRequests functionality
async decrementRateLimit(clientId: string, messageType?: string): Promise<void> {
async decrementRateLimit(
clientId: string,
messageType?: string
): Promise<void> {
// Decrement global counter
const globalTracker = this.localRateLimitTracker.get(clientId);
if (globalTracker && globalTracker.count > 0) {
@@ -389,7 +460,8 @@ export class RateLimitService {
// Decrement message type counter if provided
if (messageType) {
const messageTypeKey = `${clientId}:${messageType}`;
const messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
const messageTypeTracker =
this.localMessageTypeRateLimitTracker.get(messageTypeKey);
if (messageTypeTracker && messageTypeTracker.count > 0) {
messageTypeTracker.count--;
messageTypeTracker.pendingCount--;
@@ -401,7 +473,7 @@ export class RateLimitService {
async resetKey(clientId: string): Promise<void> {
// Remove from local tracking
this.localRateLimitTracker.delete(clientId);
// Remove all message type entries for this client
for (const [key] of this.localMessageTypeRateLimitTracker) {
if (key.startsWith(`${clientId}:`)) {
@@ -417,9 +489,13 @@ export class RateLimitService {
// Get all message type keys for this client and delete them
const client = redisManager.getClient();
if (client) {
const messageTypeKeys = await client.keys(`ratelimit:${clientId}:*`);
const messageTypeKeys = await client.keys(
`ratelimit:${clientId}:*`
);
if (messageTypeKeys.length > 0) {
await Promise.all(messageTypeKeys.map(key => redisManager.del(key)));
await Promise.all(
messageTypeKeys.map((key) => redisManager.del(key))
);
}
}
}
@@ -431,7 +507,10 @@ export class RateLimitService {
const windowStart = currentTime - RATE_LIMIT_WINDOW;
// Clean up global rate limit tracking and sync pending data
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
for (const [
clientId,
tracker
] of this.localRateLimitTracker.entries()) {
if (tracker.windowStart < windowStart) {
// Sync any pending data before cleanup
if (tracker.pendingCount > 0) {
@@ -442,12 +521,19 @@ export class RateLimitService {
}
// Clean up message type rate limit tracking and sync pending data
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
for (const [
key,
tracker
] of this.localMessageTypeRateLimitTracker.entries()) {
if (tracker.windowStart < windowStart) {
// Sync any pending data before cleanup
if (tracker.pendingCount > 0) {
const [clientId, messageType] = key.split(":", 2);
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
await this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
tracker
);
}
this.localMessageTypeRateLimitTracker.delete(key);
}
@@ -461,17 +547,27 @@ export class RateLimitService {
logger.debug("Force syncing all pending rate limit data to Redis...");
// Sync all pending global rate limits
for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
for (const [
clientId,
tracker
] of this.localRateLimitTracker.entries()) {
if (tracker.pendingCount > 0) {
await this.syncRateLimitToRedis(clientId, tracker);
}
}
// Sync all pending message type rate limits
for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
for (const [
key,
tracker
] of this.localMessageTypeRateLimitTracker.entries()) {
if (tracker.pendingCount > 0) {
const [clientId, messageType] = key.split(":", 2);
await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
await this.syncMessageTypeRateLimitToRedis(
clientId,
messageType,
tracker
);
}
}
@@ -504,4 +600,4 @@ export class RateLimitService {
}
// Export singleton instance
export const rateLimitService = new RateLimitService();
export const rateLimitService = new RateLimitService();

View File

@@ -17,7 +17,10 @@ import { MemoryStore, Store } from "express-rate-limit";
import RedisStore from "#private/lib/redisStore";
export function createStore(): Store {
if (build != "oss" && privateConfig.getRawPrivateConfig().flags.enable_redis) {
if (
build != "oss" &&
privateConfig.getRawPrivateConfig().flags.enable_redis
) {
const rateLimitStore: Store = new RedisStore({
prefix: "api-rate-limit", // Optional: customize Redis key prefix
skipFailedRequests: true, // Don't count failed requests

View File

@@ -19,7 +19,7 @@ import { build } from "@server/build";
class RedisManager {
public client: Redis | null = null;
private writeClient: Redis | null = null; // Master for writes
private readClient: Redis | null = null; // Replica for reads
private readClient: Redis | null = null; // Replica for reads
private subscriber: Redis | null = null;
private publisher: Redis | null = null;
private isEnabled: boolean = false;
@@ -46,7 +46,8 @@ class RedisManager {
this.isEnabled = false;
return;
}
this.isEnabled = privateConfig.getRawPrivateConfig().flags.enable_redis || false;
this.isEnabled =
privateConfig.getRawPrivateConfig().flags.enable_redis || false;
if (this.isEnabled) {
this.initializeClients();
}
@@ -63,15 +64,19 @@ class RedisManager {
}
private async triggerReconnectionCallbacks(): Promise<void> {
logger.info(`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`);
const promises = Array.from(this.reconnectionCallbacks).map(async (callback) => {
try {
await callback();
} catch (error) {
logger.error("Error in reconnection callback:", error);
logger.info(
`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`
);
const promises = Array.from(this.reconnectionCallbacks).map(
async (callback) => {
try {
await callback();
} catch (error) {
logger.error("Error in reconnection callback:", error);
}
}
});
);
await Promise.allSettled(promises);
}
@@ -79,13 +84,17 @@ class RedisManager {
private async resubscribeToChannels(): Promise<void> {
if (!this.subscriber || this.subscribers.size === 0) return;
logger.info(`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`);
logger.info(
`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`
);
try {
const channels = Array.from(this.subscribers.keys());
if (channels.length > 0) {
await this.subscriber.subscribe(...channels);
logger.info(`Successfully re-subscribed to channels: ${channels.join(', ')}`);
logger.info(
`Successfully re-subscribed to channels: ${channels.join(", ")}`
);
}
} catch (error) {
logger.error("Failed to re-subscribe to channels:", error);
@@ -98,7 +107,7 @@ class RedisManager {
host: redisConfig.host!,
port: redisConfig.port!,
password: redisConfig.password,
db: redisConfig.db,
db: redisConfig.db
// tls: {
// rejectUnauthorized:
// redisConfig.tls?.reject_unauthorized || false
@@ -112,7 +121,7 @@ class RedisManager {
if (!redisConfig.replicas || redisConfig.replicas.length === 0) {
return null;
}
// Use the first replica for simplicity
// In production, you might want to implement load balancing across replicas
const replica = redisConfig.replicas[0];
@@ -120,7 +129,7 @@ class RedisManager {
host: replica.host!,
port: replica.port!,
password: replica.password,
db: replica.db || redisConfig.db,
db: replica.db || redisConfig.db
// tls: {
// rejectUnauthorized:
// replica.tls?.reject_unauthorized || false
@@ -133,7 +142,7 @@ class RedisManager {
private initializeClients(): void {
const masterConfig = this.getRedisConfig();
const replicaConfig = this.getReplicaRedisConfig();
this.hasReplicas = replicaConfig !== null;
try {
@@ -144,7 +153,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Initialize replica connection for reads (if available)
@@ -155,7 +164,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
} else {
// Fallback to master for reads if no replicas
@@ -172,7 +181,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Subscriber uses replica if available (reads)
@@ -182,7 +191,7 @@ class RedisManager {
maxRetriesPerRequest: 3,
keepAlive: 30000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout,
commandTimeout: this.commandTimeout
});
// Add reconnection handlers for write client
@@ -202,11 +211,14 @@ class RedisManager {
logger.info("Redis write client ready");
this.isWriteHealthy = true;
this.updateOverallHealth();
// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -233,11 +245,14 @@ class RedisManager {
logger.info("Redis read client ready");
this.isReadHealthy = true;
this.updateOverallHealth();
// Trigger reconnection callbacks when Redis comes back online
if (this.isHealthy) {
this.triggerReconnectionCallbacks().catch(error => {
logger.error("Error triggering reconnection callbacks:", error);
this.triggerReconnectionCallbacks().catch((error) => {
logger.error(
"Error triggering reconnection callbacks:",
error
);
});
}
});
@@ -298,8 +313,8 @@ class RedisManager {
}
);
const setupMessage = this.hasReplicas
? "Redis clients initialized successfully with replica support"
const setupMessage = this.hasReplicas
? "Redis clients initialized successfully with replica support"
: "Redis clients initialized successfully (single instance)";
logger.info(setupMessage);
@@ -313,7 +328,8 @@ class RedisManager {
private updateOverallHealth(): void {
// Overall health is true if write is healthy and (read is healthy OR we don't have replicas)
this.isHealthy = this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
this.isHealthy =
this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
}
private async executeWithRetry<T>(
@@ -322,49 +338,61 @@ class RedisManager {
fallbackOperation?: () => Promise<T>
): Promise<T> {
let lastError: Error | null = null;
for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
try {
return await operation();
} catch (error) {
lastError = error as Error;
// If this is the last attempt, try fallback if available
if (attempt === this.maxRetries && fallbackOperation) {
try {
logger.warn(`${operationName} primary operation failed, trying fallback`);
logger.warn(
`${operationName} primary operation failed, trying fallback`
);
return await fallbackOperation();
} catch (fallbackError) {
logger.error(`${operationName} fallback also failed:`, fallbackError);
logger.error(
`${operationName} fallback also failed:`,
fallbackError
);
throw lastError;
}
}
// Don't retry on the last attempt
if (attempt === this.maxRetries) {
break;
}
// Calculate delay with exponential backoff
const delay = Math.min(
this.baseRetryDelay * Math.pow(this.backoffMultiplier, attempt),
this.baseRetryDelay *
Math.pow(this.backoffMultiplier, attempt),
this.maxRetryDelay
);
logger.warn(`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`, error);
logger.warn(
`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`,
error
);
// Wait before retrying
await new Promise(resolve => setTimeout(resolve, delay));
await new Promise((resolve) => setTimeout(resolve, delay));
}
}
logger.error(`${operationName} failed after ${this.maxRetries + 1} attempts:`, lastError);
logger.error(
`${operationName} failed after ${this.maxRetries + 1} attempts:`,
lastError
);
throw lastError;
}
private startHealthMonitoring(): void {
if (!this.isEnabled) return;
// Check health every 30 seconds
setInterval(async () => {
try {
@@ -381,7 +409,7 @@ class RedisManager {
private async checkRedisHealth(): Promise<boolean> {
const now = Date.now();
// Only check health every 30 seconds
if (now - this.lastHealthCheck < this.healthCheckInterval) {
return this.isHealthy;
@@ -400,24 +428,45 @@ class RedisManager {
// Check write client (master) health
await Promise.race([
this.writeClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Write client health check timeout')), 2000)
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error("Write client health check timeout")
),
2000
)
)
]);
this.isWriteHealthy = true;
// Check read client health if it's different from write client
if (this.hasReplicas && this.readClient && this.readClient !== this.writeClient) {
if (
this.hasReplicas &&
this.readClient &&
this.readClient !== this.writeClient
) {
try {
await Promise.race([
this.readClient.ping(),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Read client health check timeout')), 2000)
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error(
"Read client health check timeout"
)
),
2000
)
)
]);
this.isReadHealthy = true;
} catch (error) {
logger.error("Redis read client health check failed:", error);
logger.error(
"Redis read client health check failed:",
error
);
this.isReadHealthy = false;
}
} else {
@@ -475,16 +524,13 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.writeClient) return false;
try {
await this.executeWithRetry(
async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
},
"Redis SET"
);
await this.executeWithRetry(async () => {
if (ttl) {
await this.writeClient!.setex(key, ttl, value);
} else {
await this.writeClient!.set(key, value);
}
}, "Redis SET");
return true;
} catch (error) {
logger.error("Redis SET error:", error);
@@ -496,9 +542,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.get(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.get(key)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.get(key),
@@ -560,9 +607,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return [];
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.smembers(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.smembers(key)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.smembers(key),
@@ -598,9 +646,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.hget(key, field)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hget(key, field)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.hget(key, field),
@@ -632,9 +681,10 @@ class RedisManager {
if (!this.isRedisEnabled() || !this.readClient) return {};
try {
const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
? () => this.writeClient!.hgetall(key)
: undefined;
const fallbackOperation =
this.hasReplicas && this.writeClient && this.isWriteHealthy
? () => this.writeClient!.hgetall(key)
: undefined;
return await this.executeWithRetry(
() => this.readClient!.hgetall(key),
@@ -658,18 +708,18 @@ class RedisManager {
}
try {
await this.executeWithRetry(
async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
await this.executeWithRetry(async () => {
// Add timeout to prevent hanging
return Promise.race([
this.publisher!.publish(channel, message),
new Promise((_, reject) =>
setTimeout(
() => reject(new Error("Redis publish timeout")),
3000
)
]);
},
"Redis PUBLISH"
);
)
]);
}, "Redis PUBLISH");
return true;
} catch (error) {
logger.error("Redis PUBLISH error:", error);
@@ -689,17 +739,20 @@ class RedisManager {
if (!this.subscribers.has(channel)) {
this.subscribers.set(channel, new Set());
// Only subscribe to the channel if it's the first subscriber
await this.executeWithRetry(
async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
await this.executeWithRetry(async () => {
return Promise.race([
this.subscriber!.subscribe(channel),
new Promise((_, reject) =>
setTimeout(
() =>
reject(
new Error("Redis subscribe timeout")
),
5000
)
]);
},
"Redis SUBSCRIBE"
);
)
]);
}, "Redis SUBSCRIBE");
}
this.subscribers.get(channel)!.add(callback);

View File

@@ -11,9 +11,9 @@
* This file is not licensed under the AGPLv3.
*/
import { Store, Options, IncrementResponse } from 'express-rate-limit';
import { rateLimitService } from './rateLimit';
import logger from '@server/logger';
import { Store, Options, IncrementResponse } from "express-rate-limit";
import { rateLimitService } from "./rateLimit";
import logger from "@server/logger";
/**
* A Redis-backed rate limiting store for express-rate-limit that optimizes
@@ -57,12 +57,14 @@ export default class RedisStore implements Store {
*
* @param options - Configuration options for the store.
*/
constructor(options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}) {
this.prefix = options.prefix || 'express-rate-limit';
constructor(
options: {
prefix?: string;
skipFailedRequests?: boolean;
skipSuccessfulRequests?: boolean;
} = {}
) {
this.prefix = options.prefix || "express-rate-limit";
this.skipFailedRequests = options.skipFailedRequests || false;
this.skipSuccessfulRequests = options.skipSuccessfulRequests || false;
}
@@ -101,7 +103,8 @@ export default class RedisStore implements Store {
return {
totalHits: result.totalHits || 1,
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore increment error for key ${key}:`, error);
@@ -158,7 +161,9 @@ export default class RedisStore implements Store {
*/
async resetAll(): Promise<void> {
try {
logger.warn('RedisStore resetAll called - this operation can be expensive');
logger.warn(
"RedisStore resetAll called - this operation can be expensive"
);
// Force sync all pending data first
await rateLimitService.forceSyncAllPendingData();
@@ -167,9 +172,9 @@ export default class RedisStore implements Store {
// scanning all Redis keys with our prefix, which could be expensive.
// In production, it's better to let entries expire naturally.
logger.info('RedisStore resetAll completed (pending data synced)');
logger.info("RedisStore resetAll completed (pending data synced)");
} catch (error) {
logger.error('RedisStore resetAll error:', error);
logger.error("RedisStore resetAll error:", error);
// Don't throw - this is an optional method
}
}
@@ -181,7 +186,9 @@ export default class RedisStore implements Store {
* @param key - The identifier for a client.
* @returns Current hit count and reset time, or null if no data exists.
*/
async getHits(key: string): Promise<{ totalHits: number; resetTime: Date } | null> {
async getHits(
key: string
): Promise<{ totalHits: number; resetTime: Date } | null> {
try {
const clientId = `${this.prefix}:${key}`;
@@ -200,7 +207,8 @@ export default class RedisStore implements Store {
return {
totalHits: Math.max(0, (result.totalHits || 0) - 1), // Adjust for the decrement
resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
resetTime:
result.resetTime || new Date(Date.now() + this.windowMs)
};
} catch (error) {
logger.error(`RedisStore getHits error for key ${key}:`, error);
@@ -215,9 +223,9 @@ export default class RedisStore implements Store {
async shutdown(): Promise<void> {
try {
// The rateLimitService handles its own cleanup
logger.info('RedisStore shutdown completed');
logger.info("RedisStore shutdown completed");
} catch (error) {
logger.error('RedisStore shutdown error:', error);
logger.error("RedisStore shutdown error:", error);
}
}
}

View File

@@ -16,10 +16,10 @@ import privateConfig from "#private/lib/config";
import logger from "@server/logger";
export enum AudienceIds {
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
}
const resend = new Resend(
@@ -33,7 +33,9 @@ export async function moveEmailToAudience(
audienceId: AudienceIds
) {
if (process.env.ENVIRONMENT !== "prod") {
logger.debug(`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`);
logger.debug(
`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`
);
return;
}
const { error, data } = await retryWithBackoff(async () => {

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./getTraefikConfig";
export * from "./getTraefikConfig";

View File

@@ -19,10 +19,7 @@ import * as crypto from "crypto";
* @param publicKey - The public key used for verification (PEM format)
* @returns The decoded payload if validation succeeds, throws an error otherwise
*/
function validateJWT<Payload>(
token: string,
publicKey: string
): Payload {
function validateJWT<Payload>(token: string, publicKey: string): Payload {
// Split the JWT into its three parts
const parts = token.split(".");
if (parts.length !== 3) {

View File

@@ -41,7 +41,11 @@ async function getActionDays(orgId: string): Promise<number> {
}
// store the result in cache
cache.set(`org_${orgId}_actionDays`, org.settingsLogRetentionDaysAction, 300);
cache.set(
`org_${orgId}_actionDays`,
org.settingsLogRetentionDaysAction,
300
);
return org.settingsLogRetentionDaysAction;
}
@@ -141,4 +145,3 @@ export function logActionAudit(action: ActionsEnum) {
}
};
}

View File

@@ -28,7 +28,8 @@ export async function verifyCertificateAccess(
try {
// Assume user/org access is already verified
const orgId = req.params.orgId;
const certId = req.params.certId || req.body?.certId || req.query?.certId;
const certId =
req.params.certId || req.body?.certId || req.query?.certId;
let domainId =
req.params.domainId || req.body?.domainId || req.query?.domainId;
@@ -39,10 +40,12 @@ export async function verifyCertificateAccess(
}
if (!domainId) {
if (!certId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
createHttpError(
HttpCode.BAD_REQUEST,
"Must provide either certId or domainId"
)
);
}
@@ -75,7 +78,10 @@ export async function verifyCertificateAccess(
if (!domainId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
createHttpError(
HttpCode.BAD_REQUEST,
"Must provide either certId or domainId"
)
);
}

View File

@@ -24,8 +24,7 @@ export async function verifyIdpAccess(
) {
try {
const userId = req.user!.userId;
const idpId =
req.params.idpId || req.body.idpId || req.query.idpId;
const idpId = req.params.idpId || req.body.idpId || req.query.idpId;
const orgId = req.params.orgId;
if (!userId) {
@@ -50,9 +49,7 @@ export async function verifyIdpAccess(
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.where(
and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId))
)
.where(and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId)))
.limit(1);
if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {

View File

@@ -26,7 +26,8 @@ export const verifySessionRemoteExitNodeMiddleware = async (
// get the token from the auth header
const token = req.headers["authorization"]?.split(" ")[1] || "";
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (!session || !remoteExitNode) {
if (config.getRawConfig().app.log_failed_attempts) {

View File

@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
import {
queryAccessAuditLogsParams,
queryAccessAuditLogsQuery,
queryAccess
} from "./queryAccessAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -67,10 +71,13 @@ export async function exportAccessAuditLogs(
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +85,4 @@ export async function exportAccessAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
import {
queryActionAuditLogsParams,
queryActionAuditLogsQuery,
queryAction
} from "./queryActionAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -60,17 +64,20 @@ export async function exportActionAuditLogs(
);
}
const data = { ...parsedQuery.data, ...parsedParams.data };
const data = { ...parsedQuery.data, ...parsedParams.data };
const baseQuery = queryAction(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +85,4 @@ export async function exportActionAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -14,4 +14,4 @@
export * from "./queryActionAuditLog";
export * from "./exportActionAuditLog";
export * from "./queryAccessAuditLog";
export * from "./exportAccessAuditLog";
export * from "./exportAccessAuditLog";

View File

@@ -44,7 +44,8 @@ export const queryAccessAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z
.union([z.boolean(), z.string()])
@@ -181,9 +182,15 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter((row): row is { id: number; name: string | null } => row.id !== null),
locations: uniqueLocations.map(row => row.locations).filter((location): location is string => location !== null)
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
};
}

View File

@@ -44,7 +44,8 @@ export const queryActionAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z.string().optional(),
actorType: z.string().optional(),
@@ -68,8 +69,9 @@ export const queryActionAuditLogsParams = z.object({
orgId: z.string()
});
export const queryActionAuditLogsCombined =
queryActionAuditLogsQuery.merge(queryActionAuditLogsParams);
export const queryActionAuditLogsCombined = queryActionAuditLogsQuery.merge(
queryActionAuditLogsParams
);
type Q = z.infer<typeof queryActionAuditLogsCombined>;
function getWhere(data: Q) {
@@ -78,7 +80,9 @@ function getWhere(data: Q) {
lt(actionAuditLog.timestamp, data.timeEnd),
eq(actionAuditLog.orgId, data.orgId),
data.actor ? eq(actionAuditLog.actor, data.actor) : undefined,
data.actorType ? eq(actionAuditLog.actorType, data.actorType) : undefined,
data.actorType
? eq(actionAuditLog.actorType, data.actorType)
: undefined,
data.actorId ? eq(actionAuditLog.actorId, data.actorId) : undefined,
data.action ? eq(actionAuditLog.action, data.action) : undefined
);
@@ -135,8 +139,12 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
actions: uniqueActions.map(row => row.action).filter((action): action is string => action !== null),
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
actions: uniqueActions
.map((row) => row.action)
.filter((action): action is string => action !== null)
};
}

View File

@@ -13,4 +13,4 @@
export * from "./transferSession";
export * from "./getSessionTransferToken";
export * from "./quickStart";
export * from "./quickStart";

View File

@@ -395,7 +395,8 @@ export async function quickStart(
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
}).returning();
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
@@ -406,7 +407,12 @@ export async function quickStart(
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {

View File

@@ -26,8 +26,8 @@ import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
const createCheckoutSessionSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function createCheckoutSession(
req: Request,
@@ -72,7 +72,7 @@ export async function createCheckoutSession(
billing_address_collection: "required",
line_items: [
{
price: standardTierPrice, // Use the standard tier
price: standardTierPrice, // Use the standard tier
quantity: 1
},
...getLineItems(getStandardFeaturePriceSet())

View File

@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
const createPortalSessionSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function createPortalSession(
req: Request,

View File

@@ -34,8 +34,8 @@ import {
} from "@server/db";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
registry.registerPath({
method: "get",

View File

@@ -28,8 +28,8 @@ import { FeatureId } from "@server/lib/billing";
import { GetOrgUsageResponse } from "@server/routers/billing/types";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
registry.registerPath({
method: "get",
@@ -78,11 +78,23 @@ export async function getOrgUsage(
// Get usage for org
const usageData = [];
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
const siteUptime = await usageService.getUsage(
orgId,
FeatureId.SITE_UPTIME
);
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
const domains = await usageService.getUsageDaily(
orgId,
FeatureId.DOMAINS
);
const remoteExitNodes = await usageService.getUsageDaily(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
);
if (siteUptime) {
usageData.push(siteUptime);
@@ -100,7 +112,8 @@ export async function getOrgUsage(
usageData.push(remoteExitNodes);
}
const orgLimits = await db.select()
const orgLimits = await db
.select()
.from(limits)
.where(eq(limits.orgId, orgId));

View File

@@ -31,9 +31,7 @@ export async function handleCustomerDeleted(
return;
}
await db
.delete(customers)
.where(eq(customers.customerId, customer.id));
await db.delete(customers).where(eq(customers.customerId, customer.id));
} catch (error) {
logger.error(
`Error handling customer created event for ID ${customer.id}:`,

View File

@@ -12,7 +12,14 @@
*/
import Stripe from "stripe";
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
import {
subscriptions,
db,
subscriptionItems,
customers,
userOrgs,
users
} from "@server/db";
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
@@ -43,7 +50,6 @@ export async function handleSubscriptionDeleted(
.delete(subscriptionItems)
.where(eq(subscriptionItems.subscriptionId, subscription.id));
// Lookup customer to get orgId
const [customer] = await db
.select()
@@ -58,10 +64,7 @@ export async function handleSubscriptionDeleted(
return;
}
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
);
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
const [orgUserRes] = await db
.select()

View File

@@ -15,4 +15,4 @@ export * from "./createCheckoutSession";
export * from "./createPortalSession";
export * from "./getOrgSubscription";
export * from "./getOrgUsage";
export * from "./internalGetOrgTier";
export * from "./internalGetOrgTier";

View File

@@ -22,8 +22,8 @@ import { getOrgTierData } from "#private/lib/billing";
import { GetOrgTierResponse } from "@server/routers/billing/types";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function getOrgTier(
req: Request,

View File

@@ -11,11 +11,18 @@
* This file is not licensed under the AGPLv3.
*/
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
import {
freeLimitSet,
limitsService,
subscribedLimitSet
} from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import logger from "@server/logger";
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
export async function handleSubscriptionLifesycle(
orgId: string,
status: string
) {
switch (status) {
case "active":
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
@@ -42,4 +49,4 @@ export async function handleSubscriptionLifesycle(orgId: string, status: string)
default:
break;
}
}
}

View File

@@ -32,12 +32,13 @@ export async function billingWebhookHandler(
next: NextFunction
): Promise<any> {
let event: Stripe.Event = req.body;
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
const endpointSecret =
privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
if (!endpointSecret) {
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
logger.warn(
"Stripe webhook secret is not configured. Webhook events will not be priocessed."
);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, ""));
}
// Only verify the event if you have an endpoint secret defined.
@@ -49,7 +50,10 @@ export async function billingWebhookHandler(
if (!signature) {
logger.info("No stripe signature found in headers.");
return next(
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
createHttpError(
HttpCode.BAD_REQUEST,
"No stripe signature found in headers"
)
);
}
@@ -62,7 +66,10 @@ export async function billingWebhookHandler(
} catch (err) {
logger.error(`Webhook signature verification failed.`, err);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
createHttpError(
HttpCode.UNAUTHORIZED,
"Webhook signature verification failed"
)
);
}
}

View File

@@ -24,10 +24,10 @@ import { registry } from "@server/openApi";
import { GetCertificateResponse } from "@server/routers/certificates/types";
const getCertificateSchema = z.strictObject({
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
});
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
});
async function query(domainId: string, domain: string) {
const [domainRecord] = await db
@@ -42,8 +42,8 @@ async function query(domainId: string, domain: string) {
let existing: any[] = [];
if (domainRecord.type == "ns") {
const domainLevelDown = domain.split('.').slice(1).join('.');
const domainLevelDown = domain.split(".").slice(1).join(".");
existing = await db
.select({
certId: certificates.certId,
@@ -64,7 +64,7 @@ async function query(domainId: string, domain: string) {
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
or(
eq(certificates.domain, domain),
eq(certificates.domain, domainLevelDown),
eq(certificates.domain, domainLevelDown)
)
)
);
@@ -102,8 +102,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
domainId: z
.string(),
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
})
@@ -133,7 +132,9 @@ export async function getCertificate(
if (!cert) {
logger.warn(`Certificate not found for domain: ${domainId}`);
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
return next(
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
);
}
return response<GetCertificateResponse>(res, {

View File

@@ -12,4 +12,4 @@
*/
export * from "./getCertificate";
export * from "./restartCertificate";
export * from "./restartCertificate";

View File

@@ -25,9 +25,9 @@ import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const restartCertificateParamsSchema = z.strictObject({
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
});
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
});
registry.registerPath({
method: "post",
@@ -36,10 +36,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
certId: z
.string()
.transform(stoi)
.pipe(z.int().positive()),
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
})
},
@@ -94,7 +91,7 @@ export async function restartCertificate(
.set({
status: "pending",
errorMessage: null,
lastRenewalAttempt: Math.floor(Date.now() / 1000)
lastRenewalAttempt: Math.floor(Date.now() / 1000)
})
.where(eq(certificates.certId, certId));

View File

@@ -26,8 +26,8 @@ import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types";
const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({
subdomain: z.string()
});
subdomain: z.string()
});
registry.registerPath({
method: "get",

View File

@@ -12,4 +12,4 @@
*/
export * from "./checkDomainNamespaceAvailability";
export * from "./listDomainNamespaces";
export * from "./listDomainNamespaces";

View File

@@ -26,19 +26,19 @@ import { OpenAPITags, registry } from "@server/openApi";
const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
async function query(limit: number, offset: number) {
const res = await db

View File

@@ -79,86 +79,72 @@ import semver from "semver";
// Zod schemas for request validation
const getResourceByDomainParamsSchema = z.strictObject({
domain: z.string().min(1, "Domain is required")
});
domain: z.string().min(1, "Domain is required")
});
const getUserSessionParamsSchema = z.strictObject({
userSessionId: z.string().min(1, "User session ID is required")
});
userSessionId: z.string().min(1, "User session ID is required")
});
const getUserOrgRoleParamsSchema = z.strictObject({
userId: z.string().min(1, "User ID is required"),
orgId: z.string().min(1, "Organization ID is required")
});
userId: z.string().min(1, "User ID is required"),
orgId: z.string().min(1, "Organization ID is required")
});
const getRoleResourceAccessParamsSchema = z.strictObject({
roleId: z
.string()
.transform(Number)
.pipe(
z.int().positive("Role ID must be a positive integer")
),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
roleId: z
.string()
.transform(Number)
.pipe(z.int().positive("Role ID must be a positive integer")),
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getUserResourceAccessParamsSchema = z.strictObject({
userId: z.string().min(1, "User ID is required"),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
userId: z.string().min(1, "User ID is required"),
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getResourceRulesParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenBodySchema = z.strictObject({
token: z.string().min(1, "Token is required")
});
token: z.string().min(1, "Token is required")
});
const validateResourceAccessTokenBodySchema = z.strictObject({
accessTokenId: z.string().optional(),
resourceId: z.number().optional(),
accessToken: z.string()
});
accessTokenId: z.string().optional(),
resourceId: z.number().optional(),
accessToken: z.string()
});
// Certificates by domains query validation
const getCertificatesByDomainsQuerySchema = z.strictObject({
// Accept domains as string or array (domains or domains[])
domains: z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional(),
// Handle array format from query parameters (domains[])
"domains[]": z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional()
});
// Accept domains as string or array (domains or domains[])
domains: z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional(),
// Handle array format from query parameters (domains[])
"domains[]": z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional()
});
// Type exports for request schemas
export type GetResourceByDomainParams = z.infer<
@@ -566,8 +552,8 @@ hybridRouter.get(
);
const getOrgLoginPageParamsSchema = z.strictObject({
orgId: z.string().min(1)
});
orgId: z.string().min(1)
});
hybridRouter.get(
"/org/:orgId/login-page",
@@ -1408,8 +1394,16 @@ hybridRouter.post(
);
}
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
parsedParams.data;
const {
olmId,
newtId,
ip,
port,
timestamp,
token,
publicKey,
reachableAt
} = parsedParams.data;
const destinations = await updateAndGenerateEndpointDestinations(
olmId,

View File

@@ -18,7 +18,7 @@ import * as logs from "#private/routers/auditLogs";
import {
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess,
verifyApiKeyOrgAccess
} from "@server/middlewares";
import {
verifyValidSubscription,
@@ -26,7 +26,10 @@ import {
} from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
import {
unauthenticated as ua,
authenticated as a
} from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares";
export const unauthenticated = ua;
@@ -37,7 +40,7 @@ authenticated.post(
verifyApiKeyIsRoot, // We are the only ones who can use root key so its fine
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
logActionAudit(ActionsEnum.sendUsageNotification),
org.sendUsageNotification,
org.sendUsageNotification
);
authenticated.delete(
@@ -45,7 +48,7 @@ authenticated.delete(
verifyApiKeyIsRoot,
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
logActionAudit(ActionsEnum.deleteIdp),
orgIdp.deleteOrgIdp,
orgIdp.deleteOrgIdp
);
authenticated.get(

View File

@@ -21,8 +21,8 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
const bodySchema = z.strictObject({
licenseKey: z.string().min(1).max(255)
});
licenseKey: z.string().min(1).max(255)
});
export async function activateLicense(
req: Request,

View File

@@ -24,8 +24,8 @@ import { licenseKey } from "@server/db";
import license from "#private/license/license";
const paramsSchema = z.strictObject({
licenseKey: z.string().min(1).max(255)
});
licenseKey: z.string().min(1).max(255)
});
export async function deleteLicenseKey(
req: Request,

View File

@@ -36,13 +36,13 @@ import { build } from "@server/build";
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const bodySchema = z.strictObject({
subdomain: z.string().nullable().optional(),
domainId: z.string()
});
subdomain: z.string().nullable().optional(),
domainId: z.string()
});
export type CreateLoginPageBody = z.infer<typeof bodySchema>;
@@ -149,12 +149,20 @@ export async function createLoginPage(
let returned: LoginPage | undefined;
await db.transaction(async (trx) => {
const orgSites = await trx
.select()
.from(sites)
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.innerJoin(
exitNodes,
eq(exitNodes.exitNodeId, sites.exitNodeId)
)
.where(
and(
eq(sites.orgId, orgId),
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
let exitNodesList = orgSites.map((s) => s.exitNodes);
@@ -163,7 +171,12 @@ export async function createLoginPage(
exitNodesList = await trx
.select()
.from(exitNodes)
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.where(
and(
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
}

View File

@@ -78,15 +78,11 @@ export async function deleteLoginPage(
// if (!leftoverLinks.length) {
await db
.delete(loginPage)
.where(
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId));
await db
.delete(loginPageOrg)
.where(
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId));
// }
return response<LoginPage>(res, {

View File

@@ -23,8 +23,8 @@ import { fromError } from "zod-validation-error";
import { GetLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
async function query(orgId: string) {
const [res] = await db

View File

@@ -35,7 +35,8 @@ const paramsSchema = z
})
.strict();
const bodySchema = z.strictObject({
const bodySchema = z
.strictObject({
subdomain: subdomainSchema.nullable().optional(),
domainId: z.string().optional()
})
@@ -86,7 +87,7 @@ export async function updateLoginPage(
const { loginPageId, orgId } = parsedParams.data;
if (build === "saas"){
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
@@ -182,7 +183,10 @@ export async function updateLoginPage(
}
// update the full domain if it has changed
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
if (
fullDomain &&
fullDomain !== existingLoginPage?.fullDomain
) {
await db
.update(loginPage)
.set({ fullDomain })

View File

@@ -23,9 +23,9 @@ import SupportEmail from "@server/emails/templates/SupportEmail";
import config from "@server/lib/config";
const bodySchema = z.strictObject({
body: z.string().min(1),
subject: z.string().min(1).max(255)
});
body: z.string().min(1),
subject: z.string().min(1).max(255)
});
export async function sendSupportEmail(
req: Request,

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./sendUsageNotifications";
export * from "./sendUsageNotifications";

View File

@@ -35,10 +35,12 @@ const sendUsageNotificationBodySchema = z.object({
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
limitName: z.string(),
currentUsage: z.number(),
usageLimit: z.number(),
usageLimit: z.number()
});
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
type SendUsageNotificationRequest = z.infer<
typeof sendUsageNotificationBodySchema
>;
export type SendUsageNotificationResponse = {
success: boolean;
@@ -97,17 +99,13 @@ async function getOrgAdmins(orgId: string) {
.where(
and(
eq(userOrgs.orgId, orgId),
or(
eq(userOrgs.isOwner, true),
eq(roles.isAdmin, true)
)
or(eq(userOrgs.isOwner, true), eq(roles.isAdmin, true))
)
);
// Filter to only include users with verified emails
const orgAdmins = admins.filter(admin =>
admin.email &&
admin.email.length > 0
const orgAdmins = admins.filter(
(admin) => admin.email && admin.email.length > 0
);
return orgAdmins;
@@ -119,7 +117,9 @@ export async function sendUsageNotification(
next: NextFunction
): Promise<any> {
try {
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
const parsedParams = sendUsageNotificationParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
@@ -140,12 +140,8 @@ export async function sendUsageNotification(
}
const { orgId } = parsedParams.data;
const {
notificationType,
limitName,
currentUsage,
usageLimit,
} = parsedBody.data;
const { notificationType, limitName, currentUsage, usageLimit } =
parsedBody.data;
// Verify organization exists
const org = await db
@@ -192,7 +188,10 @@ export async function sendUsageNotification(
let template;
let subject;
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
if (
notificationType === "approaching_70" ||
notificationType === "approaching_90"
) {
template = NotifyUsageLimitApproaching({
email: admin.email,
limitName,
@@ -220,10 +219,15 @@ export async function sendUsageNotification(
emailsSent++;
adminEmails.push(admin.email);
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
logger.info(
`Usage notification sent to admin ${admin.email} for org ${orgId}`
);
} catch (emailError) {
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
logger.error(
`Failed to send usage notification to ${admin.email}:`,
emailError
);
// Continue with other admins even if one fails
}
}
@@ -239,11 +243,13 @@ export async function sendUsageNotification(
message: `Usage notifications sent to ${emailsSent} administrators`,
status: HttpCode.OK
});
} catch (error) {
logger.error("Error sending usage notifications:", error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to send usage notifications"
)
);
}
}

View File

@@ -32,19 +32,19 @@ import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
const bodySchema = z.strictObject({
name: z.string().nonempty(),
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.url(),
tokenUrl: z.url(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
});
name: z.string().nonempty(),
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.url(),
tokenUrl: z.url(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
});
// registry.registerPath({
// method: "put",
@@ -158,7 +158,10 @@ export async function createOrgOidcIdp(
});
});
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpId as number,
orgId
);
return response<CreateOrgIdpResponse>(res, {
data: {

View File

@@ -66,12 +66,7 @@ export async function deleteOrgIdp(
.where(eq(idp.idpId, idpId));
if (!existingIdp) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"IdP not found"
)
);
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
}
// Delete the IDP and its related records in a transaction
@@ -82,14 +77,10 @@ export async function deleteOrgIdp(
.where(eq(idpOidcConfig.idpId, idpId));
// Delete IDP-org mappings
await trx
.delete(idpOrg)
.where(eq(idpOrg.idpId, idpId));
await trx.delete(idpOrg).where(eq(idpOrg.idpId, idpId));
// Delete the IDP itself
await trx
.delete(idp)
.where(eq(idp.idpId, idpId));
await trx.delete(idp).where(eq(idp.idpId, idpId));
});
return response<null>(res, {

View File

@@ -93,7 +93,10 @@ export async function getOrgIdp(
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
}
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpRes.idp.idpId,
orgId
);
return response<GetOrgIdpResponse>(res, {
data: {

View File

@@ -15,4 +15,4 @@ export * from "./createOrgOidcIdp";
export * from "./getOrgIdp";
export * from "./listOrgIdps";
export * from "./updateOrgOidcIdp";
export * from "./deleteOrgIdp";
export * from "./deleteOrgIdp";

View File

@@ -25,23 +25,23 @@ import { OpenAPITags, registry } from "@server/openApi";
import { ListOrgIdpsResponse } from "@server/routers/orgIdp/types";
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
});
orgId: z.string().nonempty()
});
async function query(orgId: string, limit: number, offset: number) {
const res = await db

View File

@@ -36,18 +36,18 @@ const paramsSchema = z
.strict();
const bodySchema = z.strictObject({
name: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional(),
authUrl: z.string().optional(),
tokenUrl: z.string().optional(),
identifierPath: z.string().optional(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
});
name: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional(),
authUrl: z.string().optional(),
tokenUrl: z.string().optional(),
identifierPath: z.string().optional(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
});
export type UpdateOrgIdpResponse = {
idpId: number;

View File

@@ -13,4 +13,4 @@
export * from "./reGenerateClientSecret";
export * from "./reGenerateSiteSecret";
export * from "./reGenerateExitNodeSecret";
export * from "./reGenerateExitNodeSecret";

View File

@@ -123,7 +123,10 @@ export async function reGenerateClientSecret(
};
// Don't await this to prevent blocking the response
sendToClient(existingOlms[0].olmId, payload).catch((error) => {
logger.error("Failed to send termination message to olm:", error);
logger.error(
"Failed to send termination message to olm:",
error
);
});
disconnectClient(existingOlms[0].olmId).catch((error) => {
@@ -133,7 +136,7 @@ export async function reGenerateClientSecret(
return response(res, {
data: {
olmId: existingOlms[0].olmId,
olmId: existingOlms[0].olmId
},
success: true,
error: false,

View File

@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg, RemoteExitNode } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
RemoteExitNode
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -91,14 +98,15 @@ export async function reGenerateExitNodeSecret(
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingRemoteExitNode.remoteExitNodeId, payload).catch(
(error) => {
logger.error(
"Failed to send termination message to remote exit node:",
error
);
}
);
sendToClient(
existingRemoteExitNode.remoteExitNodeId,
payload
).catch((error) => {
logger.error(
"Failed to send termination message to remote exit node:",
error
);
});
disconnectClient(existingRemoteExitNode.remoteExitNodeId).catch(
(error) => {

View File

@@ -80,7 +80,7 @@ export async function reGenerateSiteSecret(
const secretHash = await hashPassword(secret);
// get the newt to verify it exists
const existingNewts = await db
const existingNewts = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId));
@@ -120,15 +120,20 @@ export async function reGenerateSiteSecret(
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingNewts[0].newtId, payload).catch((error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
});
sendToClient(existingNewts[0].newtId, payload).catch(
(error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
}
);
disconnectClient(existingNewts[0].newtId).catch((error) => {
logger.error("Failed to disconnect newt after re-key:", error);
logger.error(
"Failed to disconnect newt after re-key:",
error
);
});
}

View File

@@ -36,9 +36,9 @@ export const paramsSchema = z.object({
});
const bodySchema = z.strictObject({
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
export type CreateRemoteExitNodeBody = z.infer<typeof bodySchema>;

View File

@@ -25,9 +25,9 @@ import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
const paramsSchema = z.strictObject({
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
export async function deleteRemoteExitNode(
req: Request,

View File

@@ -24,9 +24,9 @@ import { fromError } from "zod-validation-error";
import { GetRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
const getRemoteExitNodeSchema = z.strictObject({
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
async function query(remoteExitNodeId: string) {
const [remoteExitNode] = await db

View File

@@ -55,7 +55,8 @@ export async function getRemoteExitNodeToken(
try {
if (token) {
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
@@ -103,7 +104,10 @@ export async function getRemoteExitNodeToken(
}
const resToken = generateSessionToken();
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
await createRemoteExitNodeSession(
resToken,
existingRemoteExitNode.remoteExitNodeId
);
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);

View File

@@ -33,7 +33,9 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const newlyOfflineNodes = await db
@@ -48,11 +50,13 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
isNull(exitNodes.lastPing)
)
)
).returning();
)
.returning();
// Update the sites to offline if they have not pinged either
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
const exitNodeIds = newlyOfflineNodes.map(
(node) => node.exitNodeId
);
const sitesOnNode = await db
.select()
@@ -77,7 +81,6 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
.where(eq(sites.siteId, site.siteId));
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
@@ -100,7 +103,9 @@ export const stopRemoteExitNodeOfflineChecker = (): void => {
/**
* Handles ping messages from clients and responds with pong
*/
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
export const handleRemoteExitNodePingMessage: MessageHandler = async (
context
) => {
const { message, client: c, sendToClient } = context;
const remoteExitNode = c as RemoteExitNode;
@@ -120,7 +125,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
.update(exitNodes)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true,
online: true
})
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
} catch (error) {
@@ -131,7 +136,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
timestamp: new Date().toISOString()
}
},
broadcast: false,

View File

@@ -29,7 +29,8 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
return;
}
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } = message.data;
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } =
message.data;
if (!remoteExitNodeVersion) {
logger.warn("Remote exit node version not found");
@@ -39,7 +40,10 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
// update the version
await db
.update(remoteExitNodes)
.set({ version: remoteExitNodeVersion, secondaryVersion: remoteExitNodeSecondaryVersion })
.set({
version: remoteExitNodeVersion,
secondaryVersion: remoteExitNodeSecondaryVersion
})
.where(
eq(
remoteExitNodes.remoteExitNodeId,

View File

@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
const listRemoteExitNodesParamsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const listRemoteExitNodesSchema = z.object({
limit: z

View File

@@ -22,8 +22,8 @@ import { z } from "zod";
import { PickRemoteExitNodeDefaultsResponse } from "@server/routers/remoteExitNode/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function pickRemoteExitNodeDefaults(
req: Request,

View File

@@ -38,7 +38,9 @@ export async function quickStartRemoteExitNode(
next: NextFunction
): Promise<any> {
try {
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(req.body);
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(
req.body
);
if (!parsedBody.success) {
return next(
createHttpError(

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./ws";
export * from "./ws";

View File

@@ -23,4 +23,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
"remoteExitNode/ping": handleRemoteExitNodePingMessage
};
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes

View File

@@ -37,7 +37,14 @@ import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remot
import { rateLimitService } from "#private/lib/rateLimit";
import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
import {
AuthenticatedWebSocket,
ClientType,
WSMessage,
TokenPayload,
WebSocketRequest,
RedisMessage
} from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app";
// Merge public and private message handlers
@@ -216,7 +223,7 @@ const initializeRedisSubscription = async (): Promise<void> => {
// Each node is responsible for restoring its own connection state to Redis
// This approach is more efficient than cross-node coordination because:
// 1. Each node knows its own connections (source of truth)
// 2. No network overhead from broadcasting state between nodes
// 2. No network overhead from broadcasting state between nodes
// 3. No race conditions from simultaneous updates
// 4. Redis becomes eventually consistent as each node restores independently
// 5. Simpler logic with better fault tolerance
@@ -233,8 +240,10 @@ const recoverConnectionState = async (): Promise<void> => {
// Each node simply restores its own local connections to Redis
// This is the source of truth - no need for cross-node coordination
await restoreLocalConnectionsToRedis();
logger.info("Redis connection state recovery completed - restored local state");
logger.info(
"Redis connection state recovery completed - restored local state"
);
} catch (error) {
logger.error("Error during Redis recovery:", error);
} finally {
@@ -251,8 +260,10 @@ const restoreLocalConnectionsToRedis = async (): Promise<void> => {
try {
// Restore all current local connections to Redis
for (const [clientId, clients] of connectedClients.entries()) {
const validClients = clients.filter(client => client.readyState === WebSocket.OPEN);
const validClients = clients.filter(
(client) => client.readyState === WebSocket.OPEN
);
if (validClients.length > 0) {
// Add this node to the client's connection list
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
@@ -303,7 +314,10 @@ const addClient = async (
Date.now().toString()
);
} catch (error) {
logger.error("Failed to add client to Redis tracking (connection still functional locally):", error);
logger.error(
"Failed to add client to Redis tracking (connection still functional locally):",
error
);
}
}
@@ -326,9 +340,14 @@ const removeClient = async (
if (redisManager.isRedisEnabled()) {
try {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
await redisManager.del(
getNodeConnectionsKey(NODE_ID, clientId)
);
} catch (error) {
logger.error("Failed to remove client from Redis tracking (cleanup will occur on recovery):", error);
logger.error(
"Failed to remove client from Redis tracking (cleanup will occur on recovery):",
error
);
}
}
@@ -345,7 +364,10 @@ const removeClient = async (
ws.connectionId
);
} catch (error) {
logger.error("Failed to remove specific connection from Redis tracking:", error);
logger.error(
"Failed to remove specific connection from Redis tracking:",
error
);
}
}
@@ -372,7 +394,9 @@ const sendToClientLocal = async (
}
});
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
);
return true;
};
@@ -411,14 +435,22 @@ const sendToClient = async (
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error("Failed to send message via Redis, message may be lost:", error);
logger.error(
"Failed to send message via Redis, message may be lost:",
error
);
// Continue execution - local delivery already attempted
}
} else if (!localSent && !redisManager.isRedisEnabled()) {
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
logger.debug(`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`);
logger.debug(
`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`
);
}
return localSent;
@@ -441,13 +473,21 @@ const broadcastToAllExcept = async (
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error("Failed to broadcast message via Redis, remote nodes may not receive it:", error);
logger.error(
"Failed to broadcast message via Redis, remote nodes may not receive it:",
error
);
// Continue execution - local broadcast already completed
}
} else {
logger.debug("Redis unavailable - broadcast limited to local node only");
logger.debug(
"Redis unavailable - broadcast limited to local node only"
);
}
};
@@ -512,8 +552,10 @@ const verifyToken = async (
return null;
}
if (olm.userId) { // this is a user device and we need to check the user token
const { session: userSession, user } = await validateSessionToken(userToken);
if (olm.userId) {
// this is a user device and we need to check the user token
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
return null;
}
@@ -668,7 +710,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
url.searchParams.get("token") ||
request.headers["sec-websocket-protocol"] ||
"";
const userToken = url.searchParams.get('userToken') || '';
const userToken = url.searchParams.get("userToken") || "";
let clientType = url.searchParams.get(
"clientType"
) as ClientType;
@@ -690,7 +732,11 @@ const handleWSUpgrade = (server: HttpServer): void => {
return;
}
const tokenPayload = await verifyToken(token, clientType, userToken);
const tokenPayload = await verifyToken(
token,
clientType,
userToken
);
if (!tokenPayload) {
logger.debug(
"Unauthorized connection attempt: invalid token..."
@@ -724,50 +770,68 @@ const handleWSUpgrade = (server: HttpServer): void => {
// Add periodic connection state sync to handle Redis disconnections/reconnections
const startPeriodicStateSync = (): void => {
// Lightweight sync every 5 minutes - just restore our own state
setInterval(async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error("Error during periodic connection state sync:", error);
setInterval(
async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error(
"Error during periodic connection state sync:",
error
);
}
}
}
}, 5 * 60 * 1000); // 5 minutes
},
5 * 60 * 1000
); // 5 minutes
// Cleanup stale connections every 15 minutes
setInterval(async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error("Error during periodic connection cleanup:", error);
setInterval(
async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error(
"Error during periodic connection cleanup:",
error
);
}
}
}
}, 15 * 60 * 1000); // 15 minutes
},
15 * 60 * 1000
); // 15 minutes
};
const cleanupStaleConnections = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
try {
const nodeKeys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
const nodeKeys =
(await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`)) ||
[];
for (const nodeKey of nodeKeys) {
const connections = await redisManager.hgetall(nodeKey);
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, '');
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, "");
const localClients = connectedClients.get(clientId) || [];
const localConnectionIds = localClients
.filter(client => client.readyState === WebSocket.OPEN)
.map(client => client.connectionId)
.filter((client) => client.readyState === WebSocket.OPEN)
.map((client) => client.connectionId)
.filter(Boolean);
// Remove Redis entries for connections that no longer exist locally
for (const [connectionId, timestamp] of Object.entries(connections)) {
for (const [connectionId, timestamp] of Object.entries(
connections
)) {
if (!localConnectionIds.includes(connectionId)) {
await redisManager.hdel(nodeKey, connectionId);
logger.debug(`Cleaned up stale connection: ${connectionId} for client: ${clientId}`);
logger.debug(
`Cleaned up stale connection: ${connectionId} for client: ${clientId}`
);
}
}
@@ -776,7 +840,9 @@ const cleanupStaleConnections = async (): Promise<void> => {
if (Object.keys(remainingConnections).length === 0) {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(nodeKey);
logger.debug(`Cleaned up empty connection tracking for client: ${clientId}`);
logger.debug(
`Cleaned up empty connection tracking for client: ${clientId}`
);
}
}
} catch (error) {
@@ -789,38 +855,38 @@ if (redisManager.isRedisEnabled()) {
initializeRedisSubscription().catch((error) => {
logger.error("Failed to initialize Redis subscription:", error);
});
// Register recovery callback with Redis manager
// When Redis reconnects, each node simply restores its own local state
redisManager.onReconnection(async () => {
logger.info("Redis reconnected, starting WebSocket state recovery...");
await recoverConnectionState();
});
// Start periodic state synchronization
startPeriodicStateSync();
logger.info(
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
);
} else {
logger.debug(
"WebSocket handler initialized in local mode"
);
logger.debug("WebSocket handler initialized in local mode");
}
// Disconnect a specific client and force them to reconnect
const disconnectClient = async (clientId: string): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
logger.debug(`No connections found for client ID: ${clientId}`);
return false;
}
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
logger.info(
`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`
);
// Close all connections for this client
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {