Compare commits

..

1 Commits

Author SHA1 Message Date
Owen
02c68b6cd3 Reconnect newts when a exit node comes back online 2026-05-20 15:38:40 -07:00
19 changed files with 400 additions and 1030 deletions

View File

@@ -154,19 +154,8 @@ class AdaptiveCache {
keys(): string[] {
return localCache.keys();
}
/**
* Get keys with a specific prefix
* @param prefix - Key prefix to match
* @returns Array of matching keys
*/
async keysWithPrefix(prefix: string): Promise<string[]> {
const allKeys = localCache.keys();
return allKeys.filter((key) => key.startsWith(prefix));
}
}
// Export singleton instance
export const cache = new AdaptiveCache();
export const regionalCache = cache; // Alias for compatability with the private version
export default cache;

View File

@@ -18,7 +18,7 @@ import {
userOrgRoles,
userSiteResources
} from "@server/db";
import { and, count, eq, inArray, ne } from "drizzle-orm";
import { and, eq, inArray, ne } from "drizzle-orm";
import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers";
import {
@@ -39,11 +39,6 @@ import {
removePeerData,
removeTargets as removeSubnetProxyTargets
} from "@server/routers/client/targets";
import { lockManager } from "#dynamic/lib/lock";
// TTL for rebuild-association locks. These functions can fan out into many
// peer/proxy updates, so give them a generous window.
const REBUILD_ASSOCIATIONS_LOCK_TTL_MS = 120000;
export async function getClientSiteResourceAccess(
siteResource: SiteResource,
@@ -166,23 +161,6 @@ export async function rebuildClientAssociationsFromSiteResource(
pubKey: string | null;
subnet: string | null;
}[];
}> {
return await lockManager.withLock(
`rebuild-client-associations:site-resource:${siteResource.siteResourceId}`,
() => rebuildClientAssociationsFromSiteResourceImpl(siteResource, trx),
REBUILD_ASSOCIATIONS_LOCK_TTL_MS
);
}
async function rebuildClientAssociationsFromSiteResourceImpl(
siteResource: SiteResource,
trx: Transaction | typeof db = db
): Promise<{
mergedAllClients: {
clientId: number;
pubKey: string | null;
subnet: string | null;
}[];
}> {
logger.debug(
`rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] START siteResourceId=${siteResource.siteResourceId} networkId=${siteResource.networkId} orgId=${siteResource.orgId}`
@@ -561,29 +539,6 @@ async function handleMessagesForSiteClients(
}
}
// get the number of sites on each of these clients so we can log it and make decisions about whether to send messages based on it
const clientSiteCounts: Record<number, number> = {};
if (clientsToProcess.size > 0) {
const clientIdsToProcess = Array.from(clientsToProcess.keys());
const siteCounts = await trx
.select({
clientId: clientSitesAssociationsCache.clientId,
siteCount: count(clientSitesAssociationsCache.siteId)
})
.from(clientSitesAssociationsCache)
.where(
inArray(
clientSitesAssociationsCache.clientId,
clientIdsToProcess
)
)
.groupBy(clientSitesAssociationsCache.clientId);
for (const row of siteCounts) {
clientSiteCounts[row.clientId] = Number(row.siteCount);
}
}
for (const client of clientsToProcess.values()) {
// UPDATE THE NEWT
if (!client.subnet || !client.pubKey) {
@@ -627,14 +582,7 @@ async function handleMessagesForSiteClients(
}
if (isAdd) {
if (clientSiteCounts[client.clientId] > 250) {
// skip adding the peer if we have more than 250 sites because we are in jit mode anyway
logger.info(
`rebuildClientAssociations: Client ${client.clientId} has ${clientSiteCounts[client.clientId]} sites so skipping adding peer to newt and olm because it is likely in jit mode`
);
continue;
}
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
@@ -652,24 +600,9 @@ async function handleMessagesForSiteClients(
exitNodeJobs.push(updateClientSiteDestinations(client, trx));
}
Promise.all(exitNodeJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`,
error
);
});
Promise.all(newtJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating Newt peers for site ${site.siteId}:`,
error
);
});
Promise.all(olmJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating Olm peers for site ${site.siteId}:`,
error
);
});
await Promise.all(exitNodeJobs);
await Promise.all(newtJobs); // do the servers first to make sure they are ready?
await Promise.all(olmJobs);
}
interface PeerDestination {
@@ -952,17 +885,6 @@ async function handleSubnetProxyTargetUpdates(
export async function rebuildClientAssociationsFromClient(
client: Client,
trx: Transaction | typeof db = db
): Promise<void> {
return await lockManager.withLock(
`rebuild-client-associations:client:${client.clientId}`,
() => rebuildClientAssociationsFromClientImpl(client, trx),
REBUILD_ASSOCIATIONS_LOCK_TTL_MS
);
}
async function rebuildClientAssociationsFromClientImpl(
client: Client,
trx: Transaction | typeof db = db
): Promise<void> {
let newSiteResourceIds: number[] = [];
@@ -1235,12 +1157,6 @@ async function handleMessagesForClientSites(
const olmJobs: Promise<any>[] = [];
const exitNodeJobs: Promise<any>[] = [];
const totalSitesOnClient = await trx
.select({ count: count(clientSitesAssociationsCache.siteId) })
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId))
.then((rows) => Number(rows[0].count));
for (const siteData of sitesData) {
const site = siteData.sites;
const exitNode = siteData.exitNodes;
@@ -1301,14 +1217,7 @@ async function handleMessagesForClientSites(
continue;
}
if (totalSitesOnClient > 250) {
// skip adding the site if we have more than 250 because we are in jit mode anyway
logger.info(
`rebuildClientAssociations: Client ${client.clientId} has ${totalSitesOnClient} sites so skipping adding peer to newt and olm because it is likely in jit mode`
);
continue;
}
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
@@ -1336,24 +1245,9 @@ async function handleMessagesForClientSites(
);
}
Promise.all(exitNodeJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`,
error
);
});
Promise.all(newtJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating Newt peers for client ${client.clientId}:`,
error
);
});
Promise.all(olmJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating Olm peers for client ${client.clientId}:`,
error
);
});
await Promise.all(exitNodeJobs);
await Promise.all(newtJobs);
await Promise.all(olmJobs);
}
async function handleMessagesForClientResources(
@@ -1634,195 +1528,3 @@ async function handleMessagesForClientResources(
await Promise.all([...proxyJobs, ...olmJobs]);
}
export type ClientAssociationsCacheVerification = {
clientId: number;
consistent: boolean;
// What permissions say the cache should contain
expectedSiteResourceIds: number[];
expectedSiteIds: number[];
// What the cache currently contains
actualSiteResourceIds: number[];
actualSiteIds: number[];
// Diff
missingSiteResourceIds: number[]; // present in expected, missing from cache
extraSiteResourceIds: number[]; // present in cache, not in expected
missingSiteIds: number[];
extraSiteIds: number[];
};
// verifyClientAssociationsCache walks the same permission-derivation logic as
// rebuildClientAssociationsFromClient but does NOT modify the database. It
// returns the expected vs actual cache contents and a boolean indicating
// whether the cache is in sync with what permissions imply.
export async function verifyClientAssociationsCache(
client: Client,
trx: Transaction | typeof db = db
): Promise<ClientAssociationsCacheVerification> {
let newSiteResourceIds: number[] = [];
// 1. Direct client associations
const directSiteResources = await trx
.select({ siteResourceId: clientSiteResources.siteResourceId })
.from(clientSiteResources)
.innerJoin(
siteResources,
eq(siteResources.siteResourceId, clientSiteResources.siteResourceId)
)
.where(
and(
eq(clientSiteResources.clientId, client.clientId),
eq(siteResources.orgId, client.orgId)
)
);
newSiteResourceIds.push(
...directSiteResources.map((r) => r.siteResourceId)
);
// 2. User-based and role-based access (if client has a userId)
if (client.userId) {
const userSiteResourceIds = await trx
.select({ siteResourceId: userSiteResources.siteResourceId })
.from(userSiteResources)
.innerJoin(
siteResources,
eq(
siteResources.siteResourceId,
userSiteResources.siteResourceId
)
)
.where(
and(
eq(userSiteResources.userId, client.userId),
eq(siteResources.orgId, client.orgId)
)
);
newSiteResourceIds.push(
...userSiteResourceIds.map((r) => r.siteResourceId)
);
const roleIds = await trx
.select({ roleId: userOrgRoles.roleId })
.from(userOrgRoles)
.where(
and(
eq(userOrgRoles.userId, client.userId),
eq(userOrgRoles.orgId, client.orgId)
)
)
.then((rows) => rows.map((row) => row.roleId));
if (roleIds.length > 0) {
const roleSiteResourceIds = await trx
.select({ siteResourceId: roleSiteResources.siteResourceId })
.from(roleSiteResources)
.innerJoin(
siteResources,
eq(
siteResources.siteResourceId,
roleSiteResources.siteResourceId
)
)
.where(
and(
inArray(roleSiteResources.roleId, roleIds),
eq(siteResources.orgId, client.orgId)
)
);
newSiteResourceIds.push(
...roleSiteResourceIds.map((r) => r.siteResourceId)
);
}
}
newSiteResourceIds = Array.from(new Set(newSiteResourceIds));
const newSiteResources =
newSiteResourceIds.length > 0
? await trx
.select()
.from(siteResources)
.where(
inArray(siteResources.siteResourceId, newSiteResourceIds)
)
: [];
const networkIds = Array.from(
new Set(
newSiteResources
.map((sr) => sr.networkId)
.filter((id): id is number => id !== null)
)
);
const newSiteIds =
networkIds.length > 0
? await trx
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, networkIds))
.then((rows) =>
Array.from(new Set(rows.map((r) => r.siteId)))
)
: [];
// Read the existing cache state
const existingResourceAssociations = await trx
.select({
siteResourceId: clientSiteResourcesAssociationsCache.siteResourceId
})
.from(clientSiteResourcesAssociationsCache)
.where(
eq(clientSiteResourcesAssociationsCache.clientId, client.clientId)
);
const existingSiteResourceIds = existingResourceAssociations.map(
(r) => r.siteResourceId
);
const existingSiteAssociations = await trx
.select({ siteId: clientSitesAssociationsCache.siteId })
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
const existingSiteIds = existingSiteAssociations.map((s) => s.siteId);
const expectedSiteResourceSet = new Set(newSiteResourceIds);
const actualSiteResourceSet = new Set(existingSiteResourceIds);
const expectedSiteSet = new Set(newSiteIds);
const actualSiteSet = new Set(existingSiteIds);
const missingSiteResourceIds = newSiteResourceIds.filter(
(id) => !actualSiteResourceSet.has(id)
);
const extraSiteResourceIds = existingSiteResourceIds.filter(
(id) => !expectedSiteResourceSet.has(id)
);
const missingSiteIds = newSiteIds.filter((id) => !actualSiteSet.has(id));
const extraSiteIds = existingSiteIds.filter(
(id) => !expectedSiteSet.has(id)
);
const consistent =
missingSiteResourceIds.length === 0 &&
extraSiteResourceIds.length === 0 &&
missingSiteIds.length === 0 &&
extraSiteIds.length === 0;
return {
clientId: client.clientId,
consistent,
expectedSiteResourceIds: Array.from(expectedSiteResourceSet).sort(
(a, b) => a - b
),
expectedSiteIds: Array.from(expectedSiteSet).sort((a, b) => a - b),
actualSiteResourceIds: Array.from(actualSiteResourceSet).sort(
(a, b) => a - b
),
actualSiteIds: Array.from(actualSiteSet).sort((a, b) => a - b),
missingSiteResourceIds: missingSiteResourceIds.sort((a, b) => a - b),
extraSiteResourceIds: extraSiteResourceIds.sort((a, b) => a - b),
missingSiteIds: missingSiteIds.sort((a, b) => a - b),
extraSiteIds: extraSiteIds.sort((a, b) => a - b)
};
}

View File

@@ -1,7 +1,7 @@
import { z } from "zod";
import { db, logsDb, statusHistory } from "@server/db";
import { and, eq, gte, asc } from "drizzle-orm";
import { regionalCache as cache } from "#dynamic/lib/cache";
import cache from "@server/lib/cache";
const STATUS_HISTORY_CACHE_TTL = 60; // seconds
@@ -66,7 +66,7 @@ export async function invalidateStatusHistoryCache(
entityId: number
): Promise<void> {
const prefix = `statusHistory:${entityType}:${entityId}:`;
const keys = await cache.keysWithPrefix(prefix);
const keys = cache.keys().filter((k) => k.startsWith(prefix));
if (keys.length > 0) {
await cache.del(keys);
}

View File

@@ -13,7 +13,7 @@
import NodeCache from "node-cache";
import logger from "@server/logger";
import { redisManager, regionalRedisManager } from "@server/private/lib/redis";
import { redisManager } from "@server/private/lib/redis";
// Create local cache with maxKeys limit to prevent memory leaks
// With ~10k requests/day and 5min TTL, 10k keys should be more than sufficient
@@ -298,147 +298,3 @@ class AdaptiveCache {
// Export singleton instance
export const cache = new AdaptiveCache();
export default cache;
/**
* Regional adaptive cache backed by the in-cluster Redis instance.
* Falls back to a local NodeCache when the regional Redis is unavailable.
* Use this for data that is regional in nature (e.g. status history) so
* reads are served from the same cluster the user is hitting.
*/
const regionalLocalCache = new NodeCache({
stdTTL: 3600,
checkperiod: 120,
maxKeys: 10000
});
class RegionalAdaptiveCache {
private useRedis(): boolean {
return (
regionalRedisManager.isRedisEnabled() &&
regionalRedisManager.getHealthStatus().isHealthy
);
}
async set(key: string, value: any, ttl?: number): Promise<boolean> {
const effectiveTtl = ttl === 0 ? undefined : ttl;
const redisTtl = ttl === 0 ? undefined : (ttl ?? 3600);
if (this.useRedis()) {
try {
const serialized = JSON.stringify(value);
const success = await regionalRedisManager.set(
key,
serialized,
redisTtl
);
if (success) {
logger.debug(`[regional] Set key in Redis: ${key}`);
return true;
}
} catch (error) {
logger.error(
`[regional] Redis set error for key ${key}:`,
error
);
}
}
const success = regionalLocalCache.set(key, value, effectiveTtl || 0);
if (success) logger.debug(`[regional] Set key in local cache: ${key}`);
return success;
}
async get<T = any>(key: string): Promise<T | undefined> {
if (this.useRedis()) {
try {
const value = await regionalRedisManager.get(key);
if (value !== null) {
logger.debug(`[regional] Cache hit in Redis: ${key}`);
return JSON.parse(value) as T;
}
logger.debug(`[regional] Cache miss in Redis: ${key}`);
return undefined;
} catch (error) {
logger.error(
`[regional] Redis get error for key ${key}:`,
error
);
}
}
const value = regionalLocalCache.get<T>(key);
if (value !== undefined) {
logger.debug(`[regional] Cache hit in local cache: ${key}`);
} else {
logger.debug(`[regional] Cache miss in local cache: ${key}`);
}
return value;
}
async del(key: string | string[]): Promise<number> {
const keys = Array.isArray(key) ? key : [key];
let deletedCount = 0;
if (this.useRedis()) {
try {
for (const k of keys) {
const success = await regionalRedisManager.del(k);
if (success) {
deletedCount++;
logger.debug(`[regional] Deleted key from Redis: ${k}`);
}
}
if (deletedCount === keys.length) return deletedCount;
deletedCount = 0;
} catch (error) {
logger.error(`[regional] Redis del error:`, error);
deletedCount = 0;
}
}
for (const k of keys) {
const count = regionalLocalCache.del(k);
if (count > 0) {
deletedCount++;
logger.debug(`[regional] Deleted key from local cache: ${k}`);
}
}
return deletedCount;
}
async has(key: string): Promise<boolean> {
if (this.useRedis()) {
try {
const value = await regionalRedisManager.get(key);
return value !== null;
} catch (error) {
logger.error(
`[regional] Redis has error for key ${key}:`,
error
);
}
}
return regionalLocalCache.has(key);
}
/**
* Returns keys matching the given prefix from whichever backend is active.
* Redis uses a KEYS scan; local cache filters in-memory keys.
*/
async keysWithPrefix(prefix: string): Promise<string[]> {
if (this.useRedis()) {
try {
return await regionalRedisManager.keys(`${prefix}*`);
} catch (error) {
logger.error(`[regional] Redis keys error:`, error);
}
}
return regionalLocalCache.keys().filter((k) => k.startsWith(prefix));
}
getCurrentBackend(): "redis" | "local" {
return this.useRedis() ? "redis" : "local";
}
}
export const regionalCache = new RegionalAdaptiveCache();

View File

@@ -73,25 +73,6 @@ export const privateConfigSchema = z
.object({
rejectUnauthorized: z.boolean().optional().default(true)
})
.optional(),
regional_redis: z
.object({
host: z.string(),
port: portSchema,
password: z
.string()
.optional()
.transform(getEnvOrYaml("REGIONAL_REDIS_PASSWORD")),
db: z.int().nonnegative().optional().default(0),
tls: z
.object({
rejectUnauthorized: z
.boolean()
.optional()
.default(true)
})
.optional()
})
.optional()
})
.optional(),

View File

@@ -109,14 +109,14 @@ class RedisManager {
password: redisConfig.password,
db: redisConfig.db
};
// Enable TLS if configured (required for AWS ElastiCache in-transit encryption)
if (redisConfig.tls) {
opts.tls = {
rejectUnauthorized: redisConfig.tls.rejectUnauthorized ?? true
};
}
return opts;
}
@@ -135,14 +135,14 @@ class RedisManager {
password: replica.password,
db: replica.db || redisConfig.db
};
// Enable TLS if configured (required for AWS ElastiCache in-transit encryption)
if (redisConfig.tls) {
opts.tls = {
rejectUnauthorized: redisConfig.tls.rejectUnauthorized ?? true
};
}
return opts;
}
@@ -855,163 +855,3 @@ class RedisManager {
export const redisManager = new RedisManager();
export const redis = redisManager.getClient();
export default redisManager;
/**
* Lightweight Redis manager for the regional (in-cluster) Redis instance.
* Connects only when `redis.regional_redis` is present in the private config
* and `flags.enable_redis` is true. No pub/sub — designed for low-latency
* caching of regionally-scoped data.
*/
class RegionalRedisManager {
private writeClient: Redis | null = null;
private readClient: Redis | null = null;
private isEnabled: boolean = false;
private isHealthy: boolean = false;
private connectionTimeout: number = 5000;
private commandTimeout: number = 5000;
constructor() {
if (build === "oss") return;
const cfg = privateConfig.getRawPrivateConfig();
if (!cfg.flags.enable_redis || !cfg.redis?.regional_redis) return;
this.isEnabled = true;
this.initializeClients();
}
private getConfig(): RedisOptions {
const r = privateConfig.getRawPrivateConfig().redis!.regional_redis!;
const opts: RedisOptions = {
host: r.host,
port: r.port,
password: r.password,
db: r.db
};
if (r.tls) {
opts.tls = { rejectUnauthorized: r.tls.rejectUnauthorized ?? true };
}
return opts;
}
private initializeClients(): void {
const cfg = this.getConfig();
const baseOpts = {
...cfg,
enableReadyCheck: false,
maxRetriesPerRequest: 3,
keepAlive: 10000,
connectTimeout: this.connectionTimeout,
commandTimeout: this.commandTimeout
};
try {
this.writeClient = new Redis(baseOpts);
// redis-1 (replica) handles reads; fall back to primary if not resolvable
this.readClient = new Redis({
...baseOpts,
host: cfg.host!.replace(/^(.*?)(\.\S+)$/, (_, h, rest) => {
// Derive replica hostname from the headless service pattern:
// redis.redis.svc.cluster.local -> redis-1.redis-headless.redis.svc.cluster.local
// If it doesn't look like a k8s service, just use the same host
return h + rest;
})
});
// For simplicity use same host for both; callers can always read from primary
// The real replica routing is handled by the StatefulSet headless service
this.readClient = this.writeClient;
this.writeClient.on("ready", () => {
logger.info("Regional Redis client ready");
this.isHealthy = true;
});
this.writeClient.on("error", (err) => {
logger.error("Regional Redis client error:", err);
this.isHealthy = false;
});
this.writeClient.on("reconnecting", () => {
logger.info("Regional Redis client reconnecting...");
this.isHealthy = false;
});
logger.info("Regional Redis client initialized");
} catch (error) {
logger.error("Failed to initialize regional Redis client:", error);
this.isEnabled = false;
}
}
public isRedisEnabled(): boolean {
return this.isEnabled && this.writeClient !== null && this.isHealthy;
}
public getHealthStatus() {
return { isEnabled: this.isEnabled, isHealthy: this.isHealthy };
}
public async set(
key: string,
value: string,
ttl?: number
): Promise<boolean> {
if (!this.isRedisEnabled() || !this.writeClient) return false;
try {
if (ttl) {
await this.writeClient.setex(key, ttl, value);
} else {
await this.writeClient.set(key, value);
}
return true;
} catch (error) {
logger.error("Regional Redis SET error:", error);
return false;
}
}
public async get(key: string): Promise<string | null> {
if (!this.isRedisEnabled() || !this.readClient) return null;
try {
return await this.readClient.get(key);
} catch (error) {
logger.error("Regional Redis GET error:", error);
return null;
}
}
public async del(key: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.writeClient) return false;
try {
await this.writeClient.del(key);
return true;
} catch (error) {
logger.error("Regional Redis DEL error:", error);
return false;
}
}
public async keys(pattern: string): Promise<string[]> {
if (!this.isRedisEnabled() || !this.readClient) return [];
try {
return await this.readClient.keys(pattern);
} catch (error) {
logger.error("Regional Redis KEYS error:", error);
return [];
}
}
public async disconnect(): Promise<void> {
try {
if (this.writeClient) {
await this.writeClient.quit();
this.writeClient = null;
}
this.readClient = null;
logger.info("Regional Redis client disconnected");
} catch (error) {
logger.error("Error disconnecting regional Redis client:", error);
}
}
}
export const regionalRedisManager = new RegionalRedisManager();

View File

@@ -32,7 +32,6 @@ import * as eventStreamingDestination from "#private/routers/eventStreamingDesti
import * as alertRule from "#private/routers/alertRule";
import * as healthChecks from "#private/routers/healthChecks";
import * as labels from "#private/routers/labels";
import * as client from "@server/routers/client";
import {
verifyOrgAccess,
@@ -830,15 +829,3 @@ authenticated.get(
verifyUserHasAction(ActionsEnum.getTarget),
healthChecks.getHealthCheckStatusHistory
);
authenticated.get(
"/client/:clientId/verify-associations-cache",
verifyClientAccess,
client.verifyClientAssociationsCache
);
authenticated.post(
"/client/:clientId/rebuild-associations-cache",
verifyClientAccess,
client.rebuildClientAssociationsCacheRoute
);

View File

@@ -26,6 +26,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, InferInsertModel } from "drizzle-orm";
import { build } from "@server/build";
import { validateLocalPath } from "@app/lib/validateLocalPath";
import config from "#private/lib/config";
const paramsSchema = z.strictObject({
@@ -34,9 +35,78 @@ const paramsSchema = z.strictObject({
const bodySchema = z.strictObject({
logoUrl: z
.string()
.optional()
.transform((val) => (val === "" ? null : val)),
.union([
z.literal(""),
z
.string()
.superRefine(async (urlOrPath, ctx) => {
const parseResult = z.url().safeParse(urlOrPath);
if (!parseResult.success) {
if (build !== "enterprise") {
ctx.addIssue({
code: "custom",
message: "Must be a valid URL"
});
return;
} else {
try {
validateLocalPath(urlOrPath);
} catch (error) {
ctx.addIssue({
code: "custom",
message: "Must be either a valid image URL or a valid pathname starting with `/` and not containing query parameters, `..` or `*`"
});
} finally {
return;
}
}
}
try {
const response = await fetch(urlOrPath, {
method: "HEAD"
}).catch(() => {
// If HEAD fails (CORS or method not allowed), try GET
return fetch(urlOrPath, { method: "GET" });
});
if (response.status !== 200) {
ctx.addIssue({
code: "custom",
message: `Failed to load image. Please check that the URL is accessible.`
});
return;
}
const contentType =
response.headers.get("content-type") ?? "";
if (!contentType.startsWith("image/")) {
ctx.addIssue({
code: "custom",
message: `URL does not point to an image. Please provide a URL to an image file (e.g., .png, .jpg, .svg).`
});
return;
}
} catch (error) {
let errorMessage =
"Unable to verify image URL. Please check that the URL is accessible and points to an image file.";
if (error instanceof TypeError && error.message.includes("fetch")) {
errorMessage =
"Network error: Unable to reach the URL. Please check your internet connection and verify the URL is correct.";
} else if (error instanceof Error) {
errorMessage = `Error verifying URL: ${error.message}`;
}
ctx.addIssue({
code: "custom",
message: errorMessage
});
}
})
])
.transform((val) => (val === "" ? null : val))
.nullish(),
logoWidth: z.coerce.number<number>().min(1),
logoHeight: z.coerce.number<number>().min(1),
resourceTitle: z.string(),

View File

@@ -0,0 +1,202 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025-2026 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import axios from "axios";
import { db, exitNodes, newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
import redisManager from "#private/lib/redis";
import { sendToClient } from "#private/routers/ws";
const INITIAL_DELAY_MS = 15 * 1000; // 15 seconds before first check
const CHECK_INTERVAL_MS = 10 * 1000; // Check every 10 seconds
const MAX_DURATION_MS = 5 * 60 * 1000; // Give up after 5 minutes
const REDIS_PENDING_SET = "exit-node-reconnect-pending";
const REDIS_HASH_PREFIX = "exit-node-reconnect:";
interface PendingReconnect {
startTime: number;
reachableAt: string;
}
// In-memory tracking for this node
const pendingReconnects = new Map<number, PendingReconnect>();
let schedulerInterval: NodeJS.Timeout | null = null;
/**
* Schedules a reconnect check for newts connected to the given exit node.
* Called when an exit node transitions from offline to online.
*/
export async function scheduleExitNodeReconnect(
exitNodeId: number,
reachableAt: string
): Promise<void> {
logger.info(
`Scheduling newt reconnect for exit node ${exitNodeId} (reachableAt: ${reachableAt})`
);
const entry: PendingReconnect = {
startTime: Date.now(),
reachableAt
};
pendingReconnects.set(exitNodeId, entry);
// Store in Redis if available for cross-node coordination
if (redisManager.isRedisEnabled()) {
await redisManager.sadd(REDIS_PENDING_SET, exitNodeId.toString());
await redisManager.hset(
`${REDIS_HASH_PREFIX}${exitNodeId}`,
"startTime",
entry.startTime.toString()
);
await redisManager.hset(
`${REDIS_HASH_PREFIX}${exitNodeId}`,
"reachableAt",
reachableAt
);
}
}
/**
* Starts the background interval that checks pending exit node reconnects.
*/
export function startExitNodeReconnectScheduler(): void {
if (schedulerInterval) {
return;
}
schedulerInterval = setInterval(async () => {
try {
await processPendingReconnects();
} catch (error) {
logger.error("Error in exit node reconnect scheduler", { error });
}
}, CHECK_INTERVAL_MS);
logger.debug("Started exit node reconnect scheduler");
}
async function processPendingReconnects(): Promise<void> {
// Merge in-memory and Redis-tracked pending reconnects
const toProcess = new Map(pendingReconnects);
if (redisManager.isRedisEnabled()) {
const redisIds = await redisManager.smembers(REDIS_PENDING_SET);
for (const idStr of redisIds) {
const id = parseInt(idStr, 10);
if (!toProcess.has(id)) {
const startTimeStr = await redisManager.hget(
`${REDIS_HASH_PREFIX}${id}`,
"startTime"
);
const reachableAt = await redisManager.hget(
`${REDIS_HASH_PREFIX}${id}`,
"reachableAt"
);
if (startTimeStr && reachableAt) {
toProcess.set(id, {
startTime: parseInt(startTimeStr, 10),
reachableAt
});
}
}
}
}
const now = Date.now();
for (const [exitNodeId, entry] of toProcess) {
const elapsed = now - entry.startTime;
// Give up after max duration
if (elapsed >= MAX_DURATION_MS) {
logger.warn(
`Exit node reconnect check timed out for exit node ${exitNodeId} after 5 minutes`
);
await removePending(exitNodeId);
continue;
}
// Respect initial delay
if (elapsed < INITIAL_DELAY_MS) {
continue;
}
// Check if the exit node HTTP endpoint is reachable
const pingUrl = `${entry.reachableAt}/ping`;
try {
await axios.get(pingUrl, { timeout: 5000 });
} catch {
logger.debug(
`Exit node ${exitNodeId} not yet reachable at ${pingUrl}`
);
continue;
}
// Node is reachable — send reconnect to all connected newts
logger.info(
`Exit node ${exitNodeId} is reachable. Sending newt/wg/reconnect to connected newts.`
);
await sendReconnectToNewts(exitNodeId);
await removePending(exitNodeId);
}
}
async function sendReconnectToNewts(exitNodeId: number): Promise<void> {
try {
const connectedNewts = await db
.select({ newtId: newts.newtId })
.from(newts)
.innerJoin(sites, eq(newts.siteId, sites.siteId))
.where(eq(sites.exitNodeId, exitNodeId));
if (connectedNewts.length === 0) {
logger.debug(
`No newts found for exit node ${exitNodeId}, nothing to reconnect`
);
return;
}
logger.info(
`Sending newt/wg/reconnect to ${connectedNewts.length} newt(s) for exit node ${exitNodeId}`
);
const reconnectMessage = {
type: "newt/wg/reconnect",
data: {}
};
await Promise.allSettled(
connectedNewts.map(({ newtId }) =>
sendToClient(newtId, reconnectMessage)
)
);
} catch (error) {
logger.error(
`Failed to send reconnect messages for exit node ${exitNodeId}`,
{ error }
);
}
}
async function removePending(exitNodeId: number): Promise<void> {
pendingReconnects.delete(exitNodeId);
if (redisManager.isRedisEnabled()) {
await redisManager.srem(REDIS_PENDING_SET, exitNodeId.toString());
await redisManager.del(`${REDIS_HASH_PREFIX}${exitNodeId}`);
}
}

View File

@@ -16,6 +16,7 @@ import { MessageHandler } from "@server/routers/ws";
import { RemoteExitNode } from "@server/db";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
import { scheduleExitNodeReconnect } from "./exitNodeReconnectScheduler";
/**
* Handles ping messages from clients and responds with pong
@@ -37,6 +38,13 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (
}
try {
// Fetch the current state before updating so we can detect the offline→online transition
const [currentExitNode] = await db
.select({ online: exitNodes.online, reachableAt: exitNodes.reachableAt })
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId))
.limit(1);
// Update the exit node's last ping timestamp
await db
.update(exitNodes)
@@ -45,6 +53,16 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (
online: true
})
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
// If the exit node was offline and is now coming online, schedule newt reconnects
if (currentExitNode && !currentExitNode.online && currentExitNode.reachableAt) {
scheduleExitNodeReconnect(
remoteExitNode.exitNodeId,
currentExitNode.reachableAt
).catch((error) => {
logger.error("Failed to schedule exit node reconnect", { error });
});
}
} catch (error) {
logger.error("Error handling ping message", { error });
}

View File

@@ -22,3 +22,4 @@ export * from "./listRemoteExitNodes";
export * from "./pickRemoteExitNodeDefaults";
export * from "./quickStartRemoteExitNode";
export * from "./offlineChecker";
export * from "./exitNodeReconnectScheduler";

View File

@@ -14,7 +14,8 @@
import {
handleRemoteExitNodeRegisterMessage,
handleRemoteExitNodePingMessage,
startRemoteExitNodeOfflineChecker
startRemoteExitNodeOfflineChecker,
startExitNodeReconnectScheduler
} from "#private/routers/remoteExitNode";
import { MessageHandler } from "@server/routers/ws";
import { build } from "@server/build";
@@ -29,4 +30,5 @@ export const messageHandlers: Record<string, MessageHandler> = {
if (build != "saas") {
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
startExitNodeReconnectScheduler(); // check pending exit node reconnects and notify newts
}

View File

@@ -10,5 +10,3 @@ export * from "./listUserDevices";
export * from "./updateClient";
export * from "./getClient";
export * from "./createUserClient";
export * from "./verifyClientAssociationsCache";
export * from "./rebuildClientAssociationsCacheRoute";

View File

@@ -1,81 +0,0 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
const paramsSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "post",
path: "/client/{clientId}/rebuild-associations-cache",
description:
"Rebuild the client's site/site-resource association cache based on current permissions.",
tags: [OpenAPITags.Client],
request: {
params: paramsSchema
},
responses: {}
});
export async function rebuildClientAssociationsCacheRoute(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
await rebuildClientAssociationsFromClient(client);
return response(res, {
data: null,
success: true,
error: false,
message: "Client association cache rebuilt successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to rebuild client association cache"
)
);
}
}

View File

@@ -1,83 +0,0 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { verifyClientAssociationsCache as verifyClientAssociationsCacheLib } from "@server/lib/rebuildClientAssociations";
const paramsSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "get",
path: "/client/{clientId}/verify-associations-cache",
description:
"Read-only check of whether the client's site/site-resource association cache matches what the current permissions imply.",
tags: [OpenAPITags.Client],
request: {
params: paramsSchema
},
responses: {}
});
export async function verifyClientAssociationsCache(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
const report = await verifyClientAssociationsCacheLib(client);
return response(res, {
data: report,
success: true,
error: false,
message: report.consistent
? "Client association cache is consistent"
: "Client association cache is INCONSISTENT",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to verify client association cache"
)
);
}
}

View File

@@ -153,65 +153,6 @@ export default function GeneralPage() {
const [approvalId, setApprovalId] = useState<number | null>(null);
const [isRefreshing, setIsRefreshing] = useState(false);
const [, startTransition] = useTransition();
const [cacheCheck, setCacheCheck] = useState<null | {
consistent: boolean;
missingSiteResourceIds: number[];
extraSiteResourceIds: number[];
missingSiteIds: number[];
extraSiteIds: number[];
expectedSiteResourceIds: number[];
actualSiteResourceIds: number[];
expectedSiteIds: number[];
actualSiteIds: number[];
}>(null);
const [isCheckingCache, setIsCheckingCache] = useState(false);
const [isRebuildingCache, setIsRebuildingCache] = useState(false);
const handleRebuildCache = async () => {
if (!client.clientId) return;
setIsRebuildingCache(true);
try {
await api.post(
`/client/${client.clientId}/rebuild-associations-cache`
);
// Re-verify after rebuild so the result refreshes
const res = await api.get(
`/client/${client.clientId}/verify-associations-cache`
);
setCacheCheck(res.data.data);
toast({
title: "Cache rebuilt",
description: "Association cache rebuilt successfully."
});
} catch (e) {
toast({
variant: "destructive",
title: "Rebuild failed",
description: formatAxiosError(e, "Failed to rebuild cache")
});
} finally {
setIsRebuildingCache(false);
}
};
const handleVerifyCache = async () => {
if (!client.clientId) return;
setIsCheckingCache(true);
try {
const res = await api.get(
`/client/${client.clientId}/verify-associations-cache`
);
setCacheCheck(res.data.data);
} catch (e) {
toast({
variant: "destructive",
title: "Cache check failed",
description: formatAxiosError(e, "Failed to verify cache")
});
} finally {
setIsCheckingCache(false);
}
};
const { env } = useEnvContext();
const showApprovalFeatures =
@@ -903,75 +844,6 @@ export default function GeneralPage() {
</SettingsSectionBody>
</SettingsSection>
)}
{/* Hidden cache verification — subtle button, dev/admin diagnostic */}
<div className="mt-8 flex flex-col gap-2 items-start opacity-30 hover:opacity-100 transition-opacity">
<button
type="button"
onClick={handleVerifyCache}
disabled={isCheckingCache}
className="text-xs text-muted-foreground underline disabled:opacity-50"
title="Verify the client's site association cache against current permissions (read-only)"
>
{isCheckingCache
? "Checking cache…"
: "Verify association cache"}
</button>
{cacheCheck && (
<div
className={
"text-xs rounded border px-2 py-1 " +
(cacheCheck.consistent
? "border-green-600 text-green-700"
: "border-red-600 text-red-700")
}
>
{cacheCheck.consistent ? (
<span className="flex items-center gap-1">
<CheckCircle2 className="h-3 w-3" />
Cache is consistent
</span>
) : (
<div className="space-y-2">
<div className="flex items-center gap-1 font-semibold">
<XCircle className="h-3 w-3" />
Cache is INCONSISTENT
</div>
<div>
Missing site resources: [
{cacheCheck.missingSiteResourceIds.join(
", "
)}
]
</div>
<div>
Extra site resources: [
{cacheCheck.extraSiteResourceIds.join(", ")}
]
</div>
<div>
Missing sites: [
{cacheCheck.missingSiteIds.join(", ")}]
</div>
<div>
Extra sites: [
{cacheCheck.extraSiteIds.join(", ")}]
</div>
<button
type="button"
onClick={handleRebuildCache}
disabled={isRebuildingCache}
className="mt-1 text-xs underline font-semibold disabled:opacity-50"
>
{isRebuildingCache
? "Rebuilding…"
: "Rebuild cache now"}
</button>
</div>
)}
</div>
)}
</div>
</SettingsContainer>
);
}

View File

@@ -44,11 +44,77 @@ export type AuthPageCustomizationProps = {
};
const AuthPageFormSchema = z.object({
logoUrl: z
.string()
.optional()
.transform((val) => (val === "" ? undefined : val)),
logoUrl: z.union([
z.literal(""),
z.string().superRefine(async (urlOrPath, ctx) => {
const parseResult = z.url().safeParse(urlOrPath);
if (!parseResult.success) {
if (build !== "enterprise") {
ctx.addIssue({
code: "custom",
message: "Must be a valid URL"
});
return;
} else {
try {
validateLocalPath(urlOrPath);
} catch (error) {
ctx.addIssue({
code: "custom",
message:
"Must be either a valid image URL or a valid pathname starting with `/` and not containing query parameters, `..` or `*`"
});
} finally {
return;
}
}
}
try {
const response = await fetch(urlOrPath, {
method: "HEAD"
}).catch(() => {
// If HEAD fails (CORS or method not allowed), try GET
return fetch(urlOrPath, { method: "GET" });
});
if (response.status !== 200) {
ctx.addIssue({
code: "custom",
message: `Failed to load image. Please check that the URL is accessible.`
});
return;
}
const contentType = response.headers.get("content-type") ?? "";
if (!contentType.startsWith("image/")) {
ctx.addIssue({
code: "custom",
message: `URL does not point to an image. Please provide a URL to an image file (e.g., .png, .jpg, .svg).`
});
return;
}
} catch (error) {
let errorMessage =
"Unable to verify image URL. Please check that the URL is accessible and points to an image file.";
if (
error instanceof TypeError &&
error.message.includes("fetch")
) {
errorMessage =
"Network error: Unable to reach the URL. Please check your internet connection and verify the URL is correct.";
} else if (error instanceof Error) {
errorMessage = `Error verifying URL: ${error.message}`;
}
ctx.addIssue({
code: "custom",
message: errorMessage
});
}
})
]),
logoWidth: z.coerce.number<number>().min(1),
logoHeight: z.coerce.number<number>().min(1),
orgTitle: z.string().optional(),

View File

@@ -318,28 +318,12 @@ export default function DeviceLoginForm({
<FormControl>
<div className="flex justify-center">
<InputOTP
maxLength={8}
maxLength={9}
pattern={REGEXP_ONLY_DIGITS_AND_CHARS}
{...field}
value={field.value
.replace(/-/g, "")
.toUpperCase()}
onPaste={(event) => {
event.preventDefault();
const pastedText =
event.clipboardData.getData(
"text"
);
const cleanedValue =
pastedText
.replace(
/[^a-zA-Z0-9]/g,
""
)
.toUpperCase()
.slice(0, 8);
field.onChange(cleanedValue);
}}
onChange={(value) => {
// Strip hyphens and convert to uppercase
const cleanedValue = value

View File

@@ -46,20 +46,6 @@ function toSshSudoMode(value: string | null | undefined): SshSudoMode {
return "none";
}
function hasOnlyAbsoluteSudoCommands(value: string | undefined): boolean {
if (!value?.trim()) return true;
const commands = value
.split(",")
.map((command) => command.trim())
.filter(Boolean);
return commands.every((command) => {
const executable = command.split(/\s+/)[0];
return executable.startsWith("/");
});
}
export type RoleFormValues = {
name: string;
description?: string;
@@ -88,33 +74,19 @@ export function RoleForm({
const { isPaidUser } = usePaidStatus();
const { env } = useEnvContext();
const formSchema = z
.object({
name: z
.string({ message: t("nameRequired") })
.min(1)
.max(32),
description: z.string().max(255).optional(),
requireDeviceApproval: z.boolean().optional(),
allowSsh: z.boolean().optional(),
sshSudoMode: z.enum(SSH_SUDO_MODE_VALUES),
sshSudoCommands: z.string().optional(),
sshCreateHomeDir: z.boolean().optional(),
sshUnixGroups: z.string().optional()
})
.superRefine((values, ctx) => {
if (
values.sshSudoMode === "commands" &&
!hasOnlyAbsoluteSudoCommands(values.sshSudoCommands)
) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ["sshSudoCommands"],
message:
"Each sudo command must start with an absolute path (for example, /usr/bin/systemctl)."
});
}
});
const formSchema = z.object({
name: z
.string({ message: t("nameRequired") })
.min(1)
.max(32),
description: z.string().max(255).optional(),
requireDeviceApproval: z.boolean().optional(),
allowSsh: z.boolean().optional(),
sshSudoMode: z.enum(SSH_SUDO_MODE_VALUES),
sshSudoCommands: z.string().optional(),
sshCreateHomeDir: z.boolean().optional(),
sshUnixGroups: z.string().optional()
});
const defaultValues: RoleFormValues = role
? {
@@ -324,9 +296,7 @@ export function RoleForm({
control={form.control}
name="allowSsh"
render={({ field }) => {
const allowSshOptions: OptionSelectOption<
"allow" | "disallow"
>[] = [
const allowSshOptions: OptionSelectOption<"allow" | "disallow">[] = [
{
value: "allow",
label: t("roleAllowSshAllow")
@@ -341,9 +311,7 @@ export function RoleForm({
<FormLabel>
{t("roleAllowSsh")}
</FormLabel>
<OptionSelect<
"allow" | "disallow"
>
<OptionSelect<"allow" | "disallow">
options={allowSshOptions}
value={
sshDisabled
@@ -354,9 +322,7 @@ export function RoleForm({
}
onChange={(v) => {
if (sshDisabled) return;
field.onChange(
v === "allow"
);
field.onChange(v === "allow");
}}
cols={2}
disabled={sshDisabled}