Merge branch 'dev' into multi-role

This commit is contained in:
miloschwartz
2026-03-24 22:01:13 -07:00
266 changed files with 7813 additions and 5880 deletions

View File

@@ -13,8 +13,12 @@
import { rateLimitService } from "#private/lib/rateLimit";
import { cleanup as wsCleanup } from "#private/routers/ws";
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
async function cleanup() {
await flushBandwidthToDb();
await flushSiteBandwidthToDb();
await rateLimitService.cleanup();
await wsCleanup();
@@ -25,4 +29,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

287
server/private/lib/cache.ts Normal file
View File

@@ -0,0 +1,287 @@
import NodeCache from "node-cache";
import logger from "@server/logger";
import { redisManager } from "@server/private/lib/redis";
// Create local cache with maxKeys limit to prevent memory leaks
// With ~10k requests/day and 5min TTL, 10k keys should be more than sufficient
export const localCache = new NodeCache({
stdTTL: 3600,
checkperiod: 120,
maxKeys: 10000
});
// Log cache statistics periodically for monitoring
setInterval(() => {
const stats = localCache.getStats();
logger.debug(
`Local cache stats - Keys: ${stats.keys}, Hits: ${stats.hits}, Misses: ${stats.misses}, Hit rate: ${stats.hits > 0 ? ((stats.hits / (stats.hits + stats.misses)) * 100).toFixed(2) : 0}%`
);
}, 300000); // Every 5 minutes
/**
* Adaptive cache that uses Redis when available in multi-node environments,
* otherwise falls back to local memory cache for single-node deployments.
*/
class AdaptiveCache {
private useRedis(): boolean {
return (
redisManager.isRedisEnabled() &&
redisManager.getHealthStatus().isHealthy
);
}
/**
* Set a value in the cache
* @param key - Cache key
* @param value - Value to cache (will be JSON stringified for Redis)
* @param ttl - Time to live in seconds (0 = no expiration; omit = 3600s for Redis)
* @returns boolean indicating success
*/
async set(key: string, value: any, ttl?: number): Promise<boolean> {
const effectiveTtl = ttl === 0 ? undefined : ttl;
const redisTtl = ttl === 0 ? undefined : (ttl ?? 3600);
if (this.useRedis()) {
try {
const serialized = JSON.stringify(value);
const success = await redisManager.set(
key,
serialized,
redisTtl
);
if (success) {
logger.debug(`Set key in Redis: ${key}`);
return true;
}
// Redis failed, fall through to local cache
logger.debug(
`Redis set failed for key ${key}, falling back to local cache`
);
} catch (error) {
logger.error(`Redis set error for key ${key}:`, error);
// Fall through to local cache
}
}
// Use local cache as fallback or primary
const success = localCache.set(key, value, effectiveTtl || 0);
if (success) {
logger.debug(`Set key in local cache: ${key}`);
}
return success;
}
/**
* Get a value from the cache
* @param key - Cache key
* @returns The cached value or undefined if not found
*/
async get<T = any>(key: string): Promise<T | undefined> {
if (this.useRedis()) {
try {
const value = await redisManager.get(key);
if (value !== null) {
logger.debug(`Cache hit in Redis: ${key}`);
return JSON.parse(value) as T;
}
logger.debug(`Cache miss in Redis: ${key}`);
return undefined;
} catch (error) {
logger.error(`Redis get error for key ${key}:`, error);
// Fall through to local cache
}
}
// Use local cache as fallback or primary
const value = localCache.get<T>(key);
if (value !== undefined) {
logger.debug(`Cache hit in local cache: ${key}`);
} else {
logger.debug(`Cache miss in local cache: ${key}`);
}
return value;
}
/**
* Delete a value from the cache
* @param key - Cache key or array of keys
* @returns Number of deleted entries
*/
async del(key: string | string[]): Promise<number> {
const keys = Array.isArray(key) ? key : [key];
let deletedCount = 0;
if (this.useRedis()) {
try {
for (const k of keys) {
const success = await redisManager.del(k);
if (success) {
deletedCount++;
logger.debug(`Deleted key from Redis: ${k}`);
}
}
if (deletedCount === keys.length) {
return deletedCount;
}
// Some Redis deletes failed, fall through to local cache
logger.debug(
`Some Redis deletes failed, falling back to local cache`
);
} catch (error) {
logger.error(
`Redis del error for keys ${keys.join(", ")}:`,
error
);
// Fall through to local cache
deletedCount = 0;
}
}
// Use local cache as fallback or primary
for (const k of keys) {
const success = localCache.del(k);
if (success > 0) {
deletedCount++;
logger.debug(`Deleted key from local cache: ${k}`);
}
}
return deletedCount;
}
/**
* Check if a key exists in the cache
* @param key - Cache key
* @returns boolean indicating if key exists
*/
async has(key: string): Promise<boolean> {
if (this.useRedis()) {
try {
const value = await redisManager.get(key);
return value !== null;
} catch (error) {
logger.error(`Redis has error for key ${key}:`, error);
// Fall through to local cache
}
}
// Use local cache as fallback or primary
return localCache.has(key);
}
/**
* Get multiple values from the cache
* @param keys - Array of cache keys
* @returns Array of values (undefined for missing keys)
*/
async mget<T = any>(keys: string[]): Promise<(T | undefined)[]> {
if (this.useRedis()) {
try {
const results: (T | undefined)[] = [];
for (const key of keys) {
const value = await redisManager.get(key);
if (value !== null) {
results.push(JSON.parse(value) as T);
} else {
results.push(undefined);
}
}
return results;
} catch (error) {
logger.error(`Redis mget error:`, error);
// Fall through to local cache
}
}
// Use local cache as fallback or primary
return keys.map((key) => localCache.get<T>(key));
}
/**
* Flush all keys from the cache
*/
async flushAll(): Promise<void> {
if (this.useRedis()) {
logger.warn(
"Adaptive cache flushAll called - Redis flush not implemented, only local cache will be flushed"
);
}
localCache.flushAll();
logger.debug("Flushed local cache");
}
/**
* Get cache statistics
* Note: Only returns local cache stats, Redis stats are not included
*/
getStats() {
return localCache.getStats();
}
/**
* Get the current cache backend being used
* @returns "redis" if Redis is available and healthy, "local" otherwise
*/
getCurrentBackend(): "redis" | "local" {
return this.useRedis() ? "redis" : "local";
}
/**
* Take a key from the cache and delete it
* @param key - Cache key
* @returns The value or undefined if not found
*/
async take<T = any>(key: string): Promise<T | undefined> {
const value = await this.get<T>(key);
if (value !== undefined) {
await this.del(key);
}
return value;
}
/**
* Get TTL (time to live) for a key
* @param key - Cache key
* @returns TTL in seconds, 0 if no expiration, -1 if key doesn't exist
*/
getTtl(key: string): number {
// Note: This only works for local cache, Redis TTL is not supported
if (this.useRedis()) {
logger.warn(
`getTtl called for key ${key} but Redis TTL lookup is not implemented`
);
}
const ttl = localCache.getTtl(key);
if (ttl === undefined) {
return -1;
}
return Math.max(0, Math.floor((ttl - Date.now()) / 1000));
}
/**
* Get all keys from the cache
* Note: Only returns local cache keys, Redis keys are not included
*/
keys(): string[] {
if (this.useRedis()) {
logger.warn(
"keys() called but Redis keys are not included, only local cache keys returned"
);
}
return localCache.keys();
}
}
// Export singleton instance
export const cache = new AdaptiveCache();
export default cache;

View File

@@ -15,9 +15,8 @@ import config from "./config";
import { certificates, db } from "@server/db";
import { and, eq, isNotNull, or, inArray, sql } from "drizzle-orm";
import { decryptData } from "@server/lib/encryption";
import * as fs from "fs";
import logger from "@server/logger";
import cache from "@server/lib/cache";
import cache from "#private/lib/cache";
let encryptionKeyHex = "";
let encryptionKey: Buffer;
@@ -55,7 +54,7 @@ export async function getValidCertificatesForDomains(
if (useCache) {
for (const domain of domains) {
const cacheKey = `cert:${domain}`;
const cachedCert = cache.get<CertificateResult>(cacheKey);
const cachedCert = await cache.get<CertificateResult>(cacheKey);
if (cachedCert) {
finalResults.push(cachedCert); // Valid cache hit
} else {
@@ -169,7 +168,7 @@ export async function getValidCertificatesForDomains(
// Add to cache for future requests, using the *requested domain* as the key
if (useCache) {
const cacheKey = `cert:${domain}`;
cache.set(cacheKey, resultCert, 180);
await cache.set(cacheKey, resultCert, 180);
}
}
}

View File

@@ -11,17 +11,17 @@
* This file is not licensed under the AGPLv3.
*/
import { accessAuditLog, db, orgs } from "@server/db";
import { accessAuditLog, logsDb, db, orgs } from "@server/db";
import { getCountryCodeForIp } from "@server/lib/geoip";
import logger from "@server/logger";
import { and, eq, lt } from "drizzle-orm";
import cache from "@server/lib/cache";
import cache from "#private/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
import { stripPortFromHost } from "@server/lib/ip";
async function getAccessDays(orgId: string): Promise<number> {
// check cache first
const cached = cache.get<number>(`org_${orgId}_accessDays`);
const cached = await cache.get<number>(`org_${orgId}_accessDays`);
if (cached !== undefined) {
return cached;
}
@@ -39,7 +39,7 @@ async function getAccessDays(orgId: string): Promise<number> {
}
// store the result in cache
cache.set(
await cache.set(
`org_${orgId}_accessDays`,
org.settingsLogRetentionDaysAction,
300
@@ -52,7 +52,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
const cutoffTimestamp = calculateCutoffTimestamp(retentionDays);
try {
await db
await logsDb
.delete(accessAuditLog)
.where(
and(
@@ -124,7 +124,7 @@ export async function logAccessAudit(data: {
? await getCountryCodeFromIp(data.requestIp)
: undefined;
await db.insert(accessAuditLog).values({
await logsDb.insert(accessAuditLog).values({
timestamp: timestamp,
orgId: data.orgId,
actorType,
@@ -146,14 +146,14 @@ export async function logAccessAudit(data: {
async function getCountryCodeFromIp(ip: string): Promise<string | undefined> {
const geoIpCacheKey = `geoip_access:${ip}`;
let cachedCountryCode: string | undefined = cache.get(geoIpCacheKey);
let cachedCountryCode: string | undefined = await cache.get(geoIpCacheKey);
if (!cachedCountryCode) {
cachedCountryCode = await getCountryCodeForIp(ip); // do it locally
// Only cache successful lookups to avoid filling cache with undefined values
if (cachedCountryCode) {
// Cache for longer since IP geolocation doesn't change frequently
cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
await cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
}
}

View File

@@ -38,10 +38,6 @@ export const privateConfigSchema = z.object({
.string()
.optional()
.transform(getEnvOrYaml("SERVER_ENCRYPTION_KEY")),
resend_api_key: z
.string()
.optional()
.transform(getEnvOrYaml("RESEND_API_KEY")),
reo_client_id: z
.string()
.optional()

View File

@@ -1,127 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Resend } from "resend";
import privateConfig from "#private/lib/config";
import logger from "@server/logger";
export enum AudienceIds {
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
}
const resend = new Resend(
privateConfig.getRawPrivateConfig().server.resend_api_key || "missing"
);
export default resend;
export async function moveEmailToAudience(
email: string,
audienceId: AudienceIds
) {
if (process.env.ENVIRONMENT !== "prod") {
logger.debug(
`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`
);
return;
}
const { error, data } = await retryWithBackoff(async () => {
const { data, error } = await resend.contacts.create({
email,
unsubscribed: false,
audienceId
});
if (error) {
throw new Error(
`Error adding email ${email} to audience ${audienceId}: ${error}`
);
}
return { error, data };
});
if (error) {
logger.error(
`Error adding email ${email} to audience ${audienceId}: ${error}`
);
return;
}
if (data) {
logger.debug(
`Added email ${email} to audience ${audienceId} with contact ID ${data.id}`
);
}
const otherAudiences = Object.values(AudienceIds).filter(
(id) => id !== audienceId
);
for (const otherAudienceId of otherAudiences) {
const { error, data } = await retryWithBackoff(async () => {
const { data, error } = await resend.contacts.remove({
email,
audienceId: otherAudienceId
});
if (error) {
throw new Error(
`Error removing email ${email} from audience ${otherAudienceId}: ${error}`
);
}
return { error, data };
});
if (error) {
logger.error(
`Error removing email ${email} from audience ${otherAudienceId}: ${error}`
);
}
if (data) {
logger.info(
`Removed email ${email} from audience ${otherAudienceId}`
);
}
}
}
type RetryOptions = {
retries?: number;
initialDelayMs?: number;
factor?: number;
};
export async function retryWithBackoff<T>(
fn: () => Promise<T>,
options: RetryOptions = {}
): Promise<T> {
const { retries = 5, initialDelayMs = 500, factor = 2 } = options;
let attempt = 0;
let delay = initialDelayMs;
while (true) {
try {
return await fn();
} catch (err) {
attempt++;
if (attempt > retries) throw err;
await new Promise((resolve) => setTimeout(resolve, delay));
delay *= factor;
}
}
}

View File

@@ -1,447 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import * as crypto from "crypto";
/**
* SSH CA "Server" - Pure TypeScript Implementation
*
* This module provides basic SSH Certificate Authority functionality using
* only Node.js built-in crypto module. No external dependencies or subprocesses.
*
* Usage:
* 1. generateCA() - Creates a new CA key pair, returns CA info including the
* TrustedUserCAKeys line to add to servers
* 2. signPublicKey() - Signs a user's public key with the CA, returns a certificate
*/
// ============================================================================
// SSH Wire Format Helpers
// ============================================================================
/**
* Encode a string in SSH wire format (4-byte length prefix + data)
*/
function encodeString(data: Buffer | string): Buffer {
const buf = typeof data === "string" ? Buffer.from(data, "utf8") : data;
const len = Buffer.alloc(4);
len.writeUInt32BE(buf.length, 0);
return Buffer.concat([len, buf]);
}
/**
* Encode a uint32 in SSH wire format (big-endian)
*/
function encodeUInt32(value: number): Buffer {
const buf = Buffer.alloc(4);
buf.writeUInt32BE(value, 0);
return buf;
}
/**
* Encode a uint64 in SSH wire format (big-endian)
*/
function encodeUInt64(value: bigint): Buffer {
const buf = Buffer.alloc(8);
buf.writeBigUInt64BE(value, 0);
return buf;
}
/**
* Decode a string from SSH wire format at the given offset
* Returns the string buffer and the new offset
*/
function decodeString(
data: Buffer,
offset: number
): { value: Buffer; newOffset: number } {
const len = data.readUInt32BE(offset);
const value = data.subarray(offset + 4, offset + 4 + len);
return { value, newOffset: offset + 4 + len };
}
// ============================================================================
// SSH Public Key Parsing/Encoding
// ============================================================================
/**
* Parse an OpenSSH public key line (e.g., "ssh-ed25519 AAAA... comment")
*/
function parseOpenSSHPublicKey(pubKeyLine: string): {
keyType: string;
keyData: Buffer;
comment: string;
} {
const parts = pubKeyLine.trim().split(/\s+/);
if (parts.length < 2) {
throw new Error("Invalid public key format");
}
const keyType = parts[0];
const keyData = Buffer.from(parts[1], "base64");
const comment = parts.slice(2).join(" ") || "";
// Verify the key type in the blob matches
const { value: blobKeyType } = decodeString(keyData, 0);
if (blobKeyType.toString("utf8") !== keyType) {
throw new Error(
`Key type mismatch: ${blobKeyType.toString("utf8")} vs ${keyType}`
);
}
return { keyType, keyData, comment };
}
/**
* Encode an Ed25519 public key in OpenSSH format
*/
function encodeEd25519PublicKey(publicKey: Buffer): Buffer {
return Buffer.concat([
encodeString("ssh-ed25519"),
encodeString(publicKey)
]);
}
/**
* Format a public key blob as an OpenSSH public key line
*/
function formatOpenSSHPublicKey(keyBlob: Buffer, comment: string = ""): string {
const { value: keyType } = decodeString(keyBlob, 0);
const base64 = keyBlob.toString("base64");
return `${keyType.toString("utf8")} ${base64}${comment ? " " + comment : ""}`;
}
// ============================================================================
// SSH Certificate Building
// ============================================================================
interface CertificateOptions {
/** Serial number for the certificate */
serial?: bigint;
/** Certificate type: 1 = user, 2 = host */
certType?: number;
/** Key ID (usually username or identifier) */
keyId: string;
/** List of valid principals (usernames the cert is valid for) */
validPrincipals: string[];
/** Valid after timestamp (seconds since epoch) */
validAfter?: bigint;
/** Valid before timestamp (seconds since epoch) */
validBefore?: bigint;
/** Critical options (usually empty for user certs) */
criticalOptions?: Map<string, string>;
/** Extensions to enable */
extensions?: string[];
}
/**
* Build the extensions section of the certificate
*/
function buildExtensions(extensions: string[]): Buffer {
// Extensions are a series of name-value pairs, sorted by name
// For boolean extensions, the value is empty
const sortedExtensions = [...extensions].sort();
const parts: Buffer[] = [];
for (const ext of sortedExtensions) {
parts.push(encodeString(ext));
parts.push(encodeString("")); // Empty value for boolean extensions
}
return encodeString(Buffer.concat(parts));
}
/**
* Build the critical options section
*/
function buildCriticalOptions(options: Map<string, string>): Buffer {
const sortedKeys = [...options.keys()].sort();
const parts: Buffer[] = [];
for (const key of sortedKeys) {
parts.push(encodeString(key));
parts.push(encodeString(encodeString(options.get(key)!)));
}
return encodeString(Buffer.concat(parts));
}
/**
* Build the valid principals section
*/
function buildPrincipals(principals: string[]): Buffer {
const parts: Buffer[] = [];
for (const principal of principals) {
parts.push(encodeString(principal));
}
return encodeString(Buffer.concat(parts));
}
/**
* Extract the raw Ed25519 public key from an OpenSSH public key blob
*/
function extractEd25519PublicKey(keyBlob: Buffer): Buffer {
const { newOffset } = decodeString(keyBlob, 0); // Skip key type
const { value: publicKey } = decodeString(keyBlob, newOffset);
return publicKey;
}
// ============================================================================
// CA Interface
// ============================================================================
export interface CAKeyPair {
/** CA private key in PEM format (keep this secret!) */
privateKeyPem: string;
/** CA public key in PEM format */
publicKeyPem: string;
/** CA public key in OpenSSH format (for TrustedUserCAKeys) */
publicKeyOpenSSH: string;
/** Raw CA public key bytes (Ed25519) */
publicKeyRaw: Buffer;
}
export interface SignedCertificate {
/** The certificate in OpenSSH format (save as id_ed25519-cert.pub or similar) */
certificate: string;
/** The certificate type string */
certType: string;
/** Serial number */
serial: bigint;
/** Key ID */
keyId: string;
/** Valid principals */
validPrincipals: string[];
/** Valid from timestamp */
validAfter: Date;
/** Valid until timestamp */
validBefore: Date;
}
// ============================================================================
// Main Functions
// ============================================================================
/**
* Generate a new SSH Certificate Authority key pair.
*
* Returns the CA keys and the line to add to /etc/ssh/sshd_config:
* TrustedUserCAKeys /etc/ssh/ca.pub
*
* Then save the publicKeyOpenSSH to /etc/ssh/ca.pub on the server.
*
* @param comment - Optional comment for the CA public key
* @returns CA key pair and configuration info
*/
export function generateCA(comment: string = "pangolin-ssh-ca"): CAKeyPair {
// Generate Ed25519 key pair
const { publicKey, privateKey } = crypto.generateKeyPairSync("ed25519", {
publicKeyEncoding: { type: "spki", format: "pem" },
privateKeyEncoding: { type: "pkcs8", format: "pem" }
});
// Get raw public key bytes
const pubKeyObj = crypto.createPublicKey(publicKey);
const rawPubKey = pubKeyObj.export({ type: "spki", format: "der" });
// Ed25519 SPKI format: 12 byte header + 32 byte key
const ed25519PubKey = rawPubKey.subarray(rawPubKey.length - 32);
// Create OpenSSH format public key
const pubKeyBlob = encodeEd25519PublicKey(ed25519PubKey);
const publicKeyOpenSSH = formatOpenSSHPublicKey(pubKeyBlob, comment);
return {
privateKeyPem: privateKey,
publicKeyPem: publicKey,
publicKeyOpenSSH,
publicKeyRaw: ed25519PubKey
};
}
// ============================================================================
// Helper Functions
// ============================================================================
/**
* Get and decrypt the SSH CA keys for an organization.
*
* @param orgId - Organization ID
* @param decryptionKey - Key to decrypt the CA private key (typically server.secret from config)
* @returns CA key pair or null if not found
*/
export async function getOrgCAKeys(
orgId: string,
decryptionKey: string
): Promise<CAKeyPair | null> {
const { db, orgs } = await import("@server/db");
const { eq } = await import("drizzle-orm");
const { decrypt } = await import("@server/lib/crypto");
const [org] = await db
.select({
sshCaPrivateKey: orgs.sshCaPrivateKey,
sshCaPublicKey: orgs.sshCaPublicKey
})
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org || !org.sshCaPrivateKey || !org.sshCaPublicKey) {
return null;
}
const privateKeyPem = decrypt(org.sshCaPrivateKey, decryptionKey);
// Extract raw public key from the OpenSSH format
const { keyData } = parseOpenSSHPublicKey(org.sshCaPublicKey);
const { newOffset } = decodeString(keyData, 0); // Skip key type
const { value: publicKeyRaw } = decodeString(keyData, newOffset);
// Get PEM format of public key
const pubKeyObj = crypto.createPublicKey({
key: privateKeyPem,
format: "pem"
});
const publicKeyPem = pubKeyObj.export({
type: "spki",
format: "pem"
}) as string;
return {
privateKeyPem,
publicKeyPem,
publicKeyOpenSSH: org.sshCaPublicKey,
publicKeyRaw
};
}
/**
* Sign a user's SSH public key with the CA, producing a certificate.
*
* The resulting certificate should be saved alongside the user's private key
* with a -cert.pub suffix. For example:
* - Private key: ~/.ssh/id_ed25519
* - Certificate: ~/.ssh/id_ed25519-cert.pub
*
* @param caPrivateKeyPem - CA private key in PEM format
* @param userPublicKeyLine - User's public key in OpenSSH format
* @param options - Certificate options (principals, validity, etc.)
* @returns Signed certificate
*/
export function signPublicKey(
caPrivateKeyPem: string,
userPublicKeyLine: string,
options: CertificateOptions
): SignedCertificate {
// Parse the user's public key
const { keyType, keyData } = parseOpenSSHPublicKey(userPublicKeyLine);
// Determine certificate type string
let certTypeString: string;
if (keyType === "ssh-ed25519") {
certTypeString = "ssh-ed25519-cert-v01@openssh.com";
} else if (keyType === "ssh-rsa") {
certTypeString = "ssh-rsa-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp256") {
certTypeString = "ecdsa-sha2-nistp256-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp384") {
certTypeString = "ecdsa-sha2-nistp384-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp521") {
certTypeString = "ecdsa-sha2-nistp521-cert-v01@openssh.com";
} else {
throw new Error(`Unsupported key type: ${keyType}`);
}
// Get CA public key from private key
const caPrivKey = crypto.createPrivateKey(caPrivateKeyPem);
const caPubKey = crypto.createPublicKey(caPrivKey);
const caRawPubKey = caPubKey.export({ type: "spki", format: "der" });
const caEd25519PubKey = caRawPubKey.subarray(caRawPubKey.length - 32);
const caPubKeyBlob = encodeEd25519PublicKey(caEd25519PubKey);
// Set defaults
const serial = options.serial ?? BigInt(Date.now());
const certType = options.certType ?? 1; // 1 = user cert
const now = BigInt(Math.floor(Date.now() / 1000));
const validAfter = options.validAfter ?? now - 60n; // 1 minute ago
const validBefore = options.validBefore ?? now + 86400n * 365n; // 1 year from now
// Default extensions for user certificates
const defaultExtensions = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc"
];
const extensions = options.extensions ?? defaultExtensions;
const criticalOptions = options.criticalOptions ?? new Map();
// Generate nonce (random bytes)
const nonce = crypto.randomBytes(32);
// Extract the public key portion from the user's key blob
// For Ed25519: skip the key type string, get the public key (already encoded)
let userKeyPortion: Buffer;
if (keyType === "ssh-ed25519") {
// Skip the key type string, take the rest (which is encodeString(32-byte-key))
const { newOffset } = decodeString(keyData, 0);
userKeyPortion = keyData.subarray(newOffset);
} else {
// For other key types, extract everything after the key type
const { newOffset } = decodeString(keyData, 0);
userKeyPortion = keyData.subarray(newOffset);
}
// Build the certificate body (to be signed)
const certBody = Buffer.concat([
encodeString(certTypeString),
encodeString(nonce),
userKeyPortion,
encodeUInt64(serial),
encodeUInt32(certType),
encodeString(options.keyId),
buildPrincipals(options.validPrincipals),
encodeUInt64(validAfter),
encodeUInt64(validBefore),
buildCriticalOptions(criticalOptions),
buildExtensions(extensions),
encodeString(""), // reserved
encodeString(caPubKeyBlob) // signature key (CA public key)
]);
// Sign the certificate body
const signature = crypto.sign(null, certBody, caPrivKey);
// Build the full signature blob (algorithm + signature)
const signatureBlob = Buffer.concat([
encodeString("ssh-ed25519"),
encodeString(signature)
]);
// Build complete certificate
const certificate = Buffer.concat([certBody, encodeString(signatureBlob)]);
// Format as OpenSSH certificate line
const certLine = `${certTypeString} ${certificate.toString("base64")} ${options.keyId}`;
return {
certificate: certLine,
certType: certTypeString,
serial,
keyId: options.keyId,
validPrincipals: options.validPrincipals,
validAfter: new Date(Number(validAfter) * 1000),
validBefore: new Date(Number(validBefore) * 1000)
};
}

View File

@@ -34,7 +34,11 @@ import {
import logger from "@server/logger";
import config from "@server/lib/config";
import { orgs, resources, sites, Target, targets } from "@server/db";
import { sanitize, validatePathRewriteConfig } from "@server/lib/traefik/utils";
import {
sanitize,
encodePath,
validatePathRewriteConfig
} from "@server/lib/traefik/utils";
import privateConfig from "#private/lib/config";
import createPathRewriteMiddleware from "@server/lib/traefik/middleware";
import {
@@ -170,7 +174,7 @@ export async function getTraefikConfig(
resourcesWithTargetsAndSites.forEach((row) => {
const resourceId = row.resourceId;
const resourceName = sanitize(row.resourceName) || "";
const targetPath = sanitize(row.path) || ""; // Handle null/undefined paths
const targetPath = encodePath(row.path); // Use encodePath to avoid collisions (e.g. "/a/b" vs "/a-b")
const pathMatchType = row.pathMatchType || "";
const rewritePath = row.rewritePath || "";
const rewritePathType = row.rewritePathType || "";
@@ -192,7 +196,7 @@ export async function getTraefikConfig(
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
const key = sanitize(mapKey);
if (!resourcesMap.has(key)) {
if (!resourcesMap.has(mapKey)) {
const validation = validatePathRewriteConfig(
row.path,
row.pathMatchType,
@@ -207,9 +211,10 @@ export async function getTraefikConfig(
return;
}
resourcesMap.set(key, {
resourcesMap.set(mapKey, {
resourceId: row.resourceId,
name: resourceName,
key: key,
fullDomain: row.fullDomain,
ssl: row.ssl,
http: row.http,
@@ -243,7 +248,7 @@ export async function getTraefikConfig(
}
// Add target with its associated site data
resourcesMap.get(key).targets.push({
resourcesMap.get(mapKey).targets.push({
resourceId: row.resourceId,
targetId: row.targetId,
ip: row.ip,
@@ -296,8 +301,9 @@ export async function getTraefikConfig(
};
// get the key and the resource
for (const [key, resource] of resourcesMap.entries()) {
for (const [, resource] of resourcesMap.entries()) {
const targets = resource.targets as TargetWithSite[];
const key = resource.key;
const routerName = `${key}-${resource.name}-router`;
const serviceName = `${key}-${resource.name}-service`;
@@ -665,7 +671,10 @@ export async function getTraefikConfig(
// TODO: HOW TO HANDLE ^^^^^^ BETTER
const anySitesOnline = targets.some(
(target) => target.site.online
(target) =>
target.site.online ||
target.site.type === "local" ||
target.site.type === "wireguard"
);
return (
@@ -793,7 +802,10 @@ export async function getTraefikConfig(
servers: (() => {
// Check if any sites are online
const anySitesOnline = targets.some(
(target) => target.site.online
(target) =>
target.site.online ||
target.site.type === "local" ||
target.site.type === "wireguard"
);
return targets

View File

@@ -12,18 +12,18 @@
*/
import { ActionsEnum } from "@server/auth/actions";
import { actionAuditLog, db, orgs } from "@server/db";
import { actionAuditLog, logsDb, db, orgs } from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import { and, eq, lt } from "drizzle-orm";
import cache from "@server/lib/cache";
import cache from "#private/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
async function getActionDays(orgId: string): Promise<number> {
// check cache first
const cached = cache.get<number>(`org_${orgId}_actionDays`);
const cached = await cache.get<number>(`org_${orgId}_actionDays`);
if (cached !== undefined) {
return cached;
}
@@ -41,7 +41,7 @@ async function getActionDays(orgId: string): Promise<number> {
}
// store the result in cache
cache.set(
await cache.set(
`org_${orgId}_actionDays`,
org.settingsLogRetentionDaysAction,
300
@@ -54,7 +54,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
const cutoffTimestamp = calculateCutoffTimestamp(retentionDays);
try {
await db
await logsDb
.delete(actionAuditLog)
.where(
and(
@@ -123,7 +123,7 @@ export function logActionAudit(action: ActionsEnum) {
metadata = JSON.stringify(req.params);
}
await db.insert(actionAuditLog).values({
await logsDb.insert(actionAuditLog).values({
timestamp,
orgId,
actorType,

View File

@@ -32,7 +32,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/access/export",
description: "Export the access audit log for an organization as CSV",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery,
params: queryAccessAuditLogsParams

View File

@@ -32,7 +32,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/action/export",
description: "Export the action audit log for an organization as CSV",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryActionAuditLogsQuery,
params: queryActionAuditLogsParams

View File

@@ -11,11 +11,11 @@
* This file is not licensed under the AGPLv3.
*/
import { accessAuditLog, db, resources } from "@server/db";
import { accessAuditLog, logsDb, resources, db, primaryDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
import { eq, gt, lt, and, count, desc } from "drizzle-orm";
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
import { OpenAPITags } from "@server/openApi";
import { z } from "zod";
import createHttpError from "http-errors";
@@ -115,15 +115,13 @@ function getWhere(data: Q) {
}
export function queryAccess(data: Q) {
return db
return logsDb
.select({
orgId: accessAuditLog.orgId,
action: accessAuditLog.action,
actorType: accessAuditLog.actorType,
actorId: accessAuditLog.actorId,
resourceId: accessAuditLog.resourceId,
resourceName: resources.name,
resourceNiceId: resources.niceId,
ip: accessAuditLog.ip,
location: accessAuditLog.location,
userAgent: accessAuditLog.userAgent,
@@ -133,16 +131,46 @@ export function queryAccess(data: Q) {
actor: accessAuditLog.actor
})
.from(accessAuditLog)
.leftJoin(
resources,
eq(accessAuditLog.resourceId, resources.resourceId)
)
.where(getWhere(data))
.orderBy(desc(accessAuditLog.timestamp), desc(accessAuditLog.id));
}
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryAccess>>) {
// If logs database is the same as main database, we can do a join
// Otherwise, we need to fetch resource details separately
const resourceIds = logs
.map(log => log.resourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0) {
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
}
// Fetch resource details from main database
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
// Create a map for quick lookup
const resourceMap = new Map(
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
);
// Enrich logs with resource details
return logs.map(log => ({
...log,
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
}));
}
export function countAccessQuery(data: Q) {
const countQuery = db
const countQuery = logsDb
.select({ count: count() })
.from(accessAuditLog)
.where(getWhere(data));
@@ -161,7 +189,7 @@ async function queryUniqueFilterAttributes(
);
// Get unique actors
const uniqueActors = await db
const uniqueActors = await logsDb
.selectDistinct({
actor: accessAuditLog.actor
})
@@ -169,7 +197,7 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
// Get unique locations
const uniqueLocations = await db
const uniqueLocations = await logsDb
.selectDistinct({
locations: accessAuditLog.location
})
@@ -177,25 +205,40 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
// Get unique resources with names
const uniqueResources = await db
const uniqueResources = await logsDb
.selectDistinct({
id: accessAuditLog.resourceId,
name: resources.name
id: accessAuditLog.resourceId
})
.from(accessAuditLog)
.leftJoin(
resources,
eq(accessAuditLog.resourceId, resources.resourceId)
)
.where(baseConditions);
// Fetch resource names from main database for the unique resource IDs
const resourceIds = uniqueResources
.map(row => row.id)
.filter((id): id is number => id !== null);
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
if (resourceIds.length > 0) {
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
resourcesWithNames = resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}));
}
return {
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
resources: resourcesWithNames,
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
@@ -206,7 +249,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/access",
description: "Query the access audit log for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery,
params: queryAccessAuditLogsParams
@@ -243,7 +286,10 @@ export async function queryAccessAuditLogs(
const baseQuery = queryAccess(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const logsRaw = await baseQuery.limit(data.limit).offset(data.offset);
// Enrich with resource details (handles cross-database scenario)
const log = await enrichWithResourceDetails(logsRaw);
const totalCountResult = await countAccessQuery(data);
const totalCount = totalCountResult[0].count;

View File

@@ -11,7 +11,7 @@
* This file is not licensed under the AGPLv3.
*/
import { actionAuditLog, db } from "@server/db";
import { actionAuditLog, logsDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
@@ -97,7 +97,7 @@ function getWhere(data: Q) {
}
export function queryAction(data: Q) {
return db
return logsDb
.select({
orgId: actionAuditLog.orgId,
action: actionAuditLog.action,
@@ -113,7 +113,7 @@ export function queryAction(data: Q) {
}
export function countActionQuery(data: Q) {
const countQuery = db
const countQuery = logsDb
.select({ count: count() })
.from(actionAuditLog)
.where(getWhere(data));
@@ -132,14 +132,14 @@ async function queryUniqueFilterAttributes(
);
// Get unique actors
const uniqueActors = await db
const uniqueActors = await logsDb
.selectDistinct({
actor: actionAuditLog.actor
})
.from(actionAuditLog)
.where(baseConditions);
const uniqueActions = await db
const uniqueActions = await logsDb
.selectDistinct({
action: actionAuditLog.action
})
@@ -160,7 +160,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/action",
description: "Query the action audit log for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryActionAuditLogsQuery,
params: queryActionAuditLogsParams

View File

@@ -31,16 +31,16 @@ const getOrgSchema = z.strictObject({
orgId: z.string()
});
registry.registerPath({
method: "get",
path: "/org/{orgId}/billing/usage",
description: "Get an organization's billing usage",
tags: [OpenAPITags.Org],
request: {
params: getOrgSchema
},
responses: {}
});
// registry.registerPath({
// method: "get",
// path: "/org/{orgId}/billing/usage",
// description: "Get an organization's billing usage",
// tags: [OpenAPITags.Org],
// request: {
// params: getOrgSchema
// },
// responses: {}
// });
export async function getOrgUsage(
req: Request,

View File

@@ -24,7 +24,6 @@ import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import stripe from "#private/lib/stripe";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import privateConfig from "#private/lib/config";
import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
@@ -172,7 +171,7 @@ export async function handleSubscriptionCreated(
const email = orgUserRes.user.email;
if (email) {
moveEmailToAudience(email, AudienceIds.Subscribed);
// TODO: update user in Sendy
}
}
} else if (type === "license") {

View File

@@ -23,7 +23,6 @@ import {
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import stripe from "#private/lib/stripe";
import privateConfig from "#private/lib/config";
@@ -109,7 +108,7 @@ export async function handleSubscriptionDeleted(
const email = orgUserRes.user.email;
if (email) {
moveEmailToAudience(email, AudienceIds.Churned);
// TODO: update user in Sendy
}
}
} else if (type === "license") {

View File

@@ -480,9 +480,9 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyClientAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifyClientAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret
@@ -490,9 +490,9 @@ authenticated.post(
authenticated.post(
"/re-key/:siteId/regenerate-site-secret",
verifySiteAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifySiteAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret
@@ -515,6 +515,6 @@ authenticated.post(
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.signSshKey),
logActionAudit(ActionsEnum.signSshKey),
// logActionAudit(ActionsEnum.signSshKey), // it is handled inside of the function below so we can include more metadata
ssh.signSshKey
);

View File

@@ -15,6 +15,7 @@ import { verifySessionRemoteExitNodeMiddleware } from "#private/middlewares/veri
import { Router } from "express";
import {
db,
logsDb,
exitNodes,
Resource,
ResourcePassword,
@@ -81,6 +82,7 @@ import { verifyResourceAccessToken } from "@server/auth/verifyResourceAccessToke
import semver from "semver";
import { maxmindAsnLookup } from "@server/db/maxmindAsn";
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
import { sanitizeString } from "@server/lib/sanitize";
// Zod schemas for request validation
const getResourceByDomainParamsSchema = z.strictObject({
@@ -1859,24 +1861,24 @@ hybridRouter.post(
})
.map((logEntry) => ({
timestamp: logEntry.timestamp,
orgId: logEntry.orgId,
actorType: logEntry.actorType,
actor: logEntry.actor,
actorId: logEntry.actorId,
metadata: logEntry.metadata,
orgId: sanitizeString(logEntry.orgId),
actorType: sanitizeString(logEntry.actorType),
actor: sanitizeString(logEntry.actor),
actorId: sanitizeString(logEntry.actorId),
metadata: sanitizeString(logEntry.metadata),
action: logEntry.action,
resourceId: logEntry.resourceId,
reason: logEntry.reason,
location: logEntry.location,
location: sanitizeString(logEntry.location),
// userAgent: data.userAgent, // TODO: add this
// headers: data.body.headers,
// query: data.body.query,
originalRequestURL: logEntry.originalRequestURL,
scheme: logEntry.scheme,
host: logEntry.host,
path: logEntry.path,
method: logEntry.method,
ip: logEntry.ip,
originalRequestURL: sanitizeString(logEntry.originalRequestURL) ?? "",
scheme: sanitizeString(logEntry.scheme) ?? "",
host: sanitizeString(logEntry.host) ?? "",
path: sanitizeString(logEntry.path) ?? "",
method: sanitizeString(logEntry.method) ?? "",
ip: sanitizeString(logEntry.ip),
tls: logEntry.tls
}));
@@ -1884,7 +1886,7 @@ hybridRouter.post(
const batchSize = 100;
for (let i = 0; i < logEntries.length; i += batchSize) {
const batch = logEntries.slice(i, i + batchSize);
await db.insert(requestAuditLog).values(batch);
await logsDb.insert(requestAuditLog).values(batch);
}
return response(res, {

View File

@@ -52,7 +52,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/idp/oidc",
description: "Create an OIDC IdP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -35,7 +35,7 @@ registry.registerPath({
method: "delete",
path: "/org/{orgId}/idp/{idpId}",
description: "Delete IDP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema
},

View File

@@ -50,9 +50,9 @@ async function query(idpId: number, orgId: string) {
registry.registerPath({
method: "get",
path: "/org/:orgId/idp/:idpId",
path: "/org/{orgId}/idp/{idpId}",
description: "Get an IDP by its IDP ID for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema
},

View File

@@ -67,7 +67,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/idp",
description: "List all IDP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
query: querySchema,
params: paramsSchema

View File

@@ -59,7 +59,7 @@ registry.registerPath({
method: "post",
path: "/org/{orgId}/idp/{idpId}/oidc",
description: "Update an OIDC IdP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -38,7 +38,7 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const newlyOfflineNodes = await db
const offlineNodes = await db
.update(exitNodes)
.set({ online: false })
.where(
@@ -53,32 +53,15 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
)
.returning();
// Update the sites to offline if they have not pinged either
const exitNodeIds = newlyOfflineNodes.map(
(node) => node.exitNodeId
);
const sitesOnNode = await db
.select()
.from(sites)
.where(
and(
eq(sites.online, true),
inArray(sites.exitNodeId, exitNodeIds)
)
if (offlineNodes.length > 0) {
logger.info(
`checkRemoteExitNodeOffline: Marked ${offlineNodes.length} remoteExitNode client(s) offline due to inactivity`
);
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
for (const site of sitesOnNode) {
if (!site.lastBandwidthUpdate) {
continue;
}
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
if (Date.now() - lastBandwidthUpdate.getTime() > 60 * 1000) {
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, site.siteId));
for (const offlineClient of offlineNodes) {
logger.debug(
`checkRemoteExitNodeOffline: Client ${offlineClient.exitNodeId} marked offline (lastPing: ${offlineClient.lastPing})`
);
}
}
} catch (error) {

View File

@@ -52,7 +52,7 @@ registry.registerPath({
method: "get",
path: "/maintenance/info",
description: "Get maintenance information for a resource by domain.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
query: z.object({
fullDomain: z.string()

View File

@@ -14,7 +14,9 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import {
actionAuditLog,
db,
logsDb,
newts,
roles,
roundTripMessageTracker,
@@ -29,12 +31,12 @@ import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { and, eq, inArray, or } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA";
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
import config from "@server/lib/config";
import { sendToClient } from "#private/routers/ws";
import { ActionsEnum } from "@server/auth/actions";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
@@ -64,6 +66,7 @@ export type SignSshKeyResponse = {
sshUsername: string;
sshHost: string;
resourceId: number;
siteId: number;
keyId: string;
validPrincipals: string[];
validAfter: string;
@@ -185,7 +188,7 @@ export async function signSshKey(
} else if (req.user?.username) {
usernameToUse = req.user.username;
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "-");
if (!usernameToUse) {
return next(
createHttpError(
@@ -203,6 +206,9 @@ export async function signSshKey(
);
}
// prefix with p-
usernameToUse = `p-${usernameToUse}`;
// check if we have a existing user in this org with the same
const [existingUserWithSameName] = await db
.select()
@@ -248,6 +254,16 @@ export async function signSshKey(
);
}
}
await db
.update(userOrgs)
.set({ pamUsername: usernameToUse })
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, userId)
)
);
} else {
usernameToUse = userOrg.pamUsername;
}
@@ -319,7 +335,16 @@ export async function signSshKey(
);
}
// Check if the user has access to the resource (any of their roles)
if (resource.mode == "cidr") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"SSHing is not supported for CIDR resources"
)
);
}
// Check if the user has access to the resource
const hasAccess = await canUserAccessSiteResource({
userId: userId,
resourceId: resource.siteResourceId,
@@ -444,6 +469,20 @@ export async function signSshKey(
sshHost = resource.destination;
}
await logsDb.insert(actionAuditLog).values({
timestamp: Math.floor(Date.now() / 1000),
orgId: orgId,
actorType: "user",
actor: req.user?.username ?? "",
actorId: req.user?.userId ?? "",
action: ActionsEnum.signSshKey,
metadata: JSON.stringify({
resourceId: resource.siteResourceId,
resource: resource.name,
siteId: resource.siteId,
})
});
return response<SignSshKeyResponse>(res, {
data: {
certificate: cert.certificate,
@@ -451,6 +490,7 @@ export async function signSshKey(
sshUsername: usernameToUse,
sshHost: sshHost,
resourceId: resource.siteResourceId,
siteId: resource.siteId,
keyId: cert.keyId,
validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(),

View File

@@ -17,10 +17,13 @@ import {
startRemoteExitNodeOfflineChecker
} from "#private/routers/remoteExitNode";
import { MessageHandler } from "@server/routers/ws";
import { build } from "@server/build";
export const messageHandlers: Record<string, MessageHandler> = {
"remoteExitNode/register": handleRemoteExitNodeRegisterMessage,
"remoteExitNode/ping": handleRemoteExitNodePingMessage
};
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
if (build != "saas") {
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
}

View File

@@ -12,6 +12,7 @@
*/
import { Router, Request, Response } from "express";
import zlib from "zlib";
import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net";
@@ -24,7 +25,8 @@ import {
OlmSession,
RemoteExitNode,
RemoteExitNodeSession,
remoteExitNodes
remoteExitNodes,
sites
} from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
@@ -57,11 +59,13 @@ const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection
const processMessage = async (
ws: AuthenticatedWebSocket,
data: Buffer,
isBinary: boolean,
clientId: string,
clientType: ClientType
): Promise<void> => {
try {
const message: WSMessage = JSON.parse(data.toString());
const messageBuffer = isBinary ? zlib.gunzipSync(data) : data;
const message: WSMessage = JSON.parse(messageBuffer.toString());
// logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
@@ -76,7 +80,7 @@ const processMessage = async (
clientId,
message.type, // Pass message type for granular limiting
100, // max requests per window
20, // max requests per message type per window
100, // max requests per message type per window
60 * 1000 // window in milliseconds
);
if (rateLimitResult.isLimited) {
@@ -163,8 +167,16 @@ const processPendingMessages = async (
);
const jobs = [];
for (const messageData of ws.pendingMessages) {
jobs.push(processMessage(ws, messageData, clientId, clientType));
for (const pending of ws.pendingMessages) {
jobs.push(
processMessage(
ws,
pending.data,
pending.isBinary,
clientId,
clientType
)
);
}
await Promise.all(jobs);
@@ -185,6 +197,12 @@ const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (local to this node, resets on server restart)
const clientConfigVersions: Map<string, number> = new Map();
// Tracks the last Unix timestamp (seconds) at which a ping was flushed to the
// DB for a given siteId. Resets on server restart which is fine the first
// ping after startup will always write, re-establishing the online state.
const lastPingDbWrite: Map<number, number> = new Map();
const PING_DB_WRITE_INTERVAL = 45; // seconds
// Recovery tracking
let isRedisRecoveryInProgress = false;
@@ -325,7 +343,9 @@ const addClient = async (
// Check Redis first if enabled
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10);
// Sync to local cache
@@ -337,7 +357,10 @@ const addClient = async (
} else {
// Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
await redisManager.set(
getConfigVersionKey(clientId),
configVersion.toString()
);
}
} catch (error) {
logger.error("Failed to get/set config version in Redis:", error);
@@ -432,7 +455,9 @@ const removeClient = async (
};
// Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
const getClientConfigVersion = async (
clientId: string
): Promise<number | undefined> => {
// Try Redis first if available
if (redisManager.isRedisEnabled()) {
try {
@@ -502,11 +527,26 @@ const sendToClientLocal = async (
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
if (options.compress) {
logger.debug(
`Message size before compression: ${messageString.length} bytes`
);
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
logger.debug(
`Message size after compression: ${compressed.length} bytes`
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
}
return true;
};
@@ -532,11 +572,22 @@ const broadcastToAllExceptLocal = async (
configVersion
};
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
if (options.compress) {
const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
}
}
}
};
@@ -762,7 +813,7 @@ const setupConnection = async (
}
// Set up message handler FIRST to prevent race condition
ws.on("message", async (data) => {
ws.on("message", async (data, isBinary) => {
if (!ws.isFullyConnected) {
// Queue message for later processing with limits
ws.pendingMessages = ws.pendingMessages || [];
@@ -777,11 +828,17 @@ const setupConnection = async (
logger.debug(
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
);
ws.pendingMessages.push(data as Buffer);
ws.pendingMessages.push({ data: data as Buffer, isBinary });
return;
}
await processMessage(ws, data as Buffer, clientId, clientType);
await processMessage(
ws,
data as Buffer,
isBinary,
clientId,
clientType
);
});
// Set up other event handlers before async operations
@@ -796,6 +853,35 @@ const setupConnection = async (
);
});
// Handle WebSocket protocol-level pings from older newt clients that do
// not send application-level "newt/ping" messages. Update the site's
// online state and lastPing timestamp so the offline checker treats them
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
if (!newtClient.siteId) return;
const now = Math.floor(Date.now() / 1000);
const lastWrite = lastPingDbWrite.get(newtClient.siteId) ?? 0;
if (now - lastWrite < PING_DB_WRITE_INTERVAL) return;
lastPingDbWrite.set(newtClient.siteId, now);
try {
await db
.update(sites)
.set({
online: true,
lastPing: now
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
});
}
ws.on("error", (error: Error) => {
logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,