From 78ad2d17c7b0149751b09a37e5ea7bcb62e43350 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 14 May 2026 12:25:05 -0700 Subject: [PATCH] Optimize building aliases in jit mode --- server/routers/olm/buildConfiguration.ts | 155 ++++++++++-------- .../routers/olm/handleOlmRegisterMessage.ts | 42 ++--- 2 files changed, 112 insertions(+), 85 deletions(-) diff --git a/server/routers/olm/buildConfiguration.ts b/server/routers/olm/buildConfiguration.ts index 640031bca..41bb6d60d 100644 --- a/server/routers/olm/buildConfiguration.ts +++ b/server/routers/olm/buildConfiguration.ts @@ -5,6 +5,7 @@ import { db, exitNodes, networks, + SiteResource, siteNetworks, siteResources, sites @@ -15,7 +16,7 @@ import { generateRemoteSubnets } from "@server/lib/ip"; import logger from "@server/logger"; -import { and, eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import config from "@server/lib/config"; @@ -46,49 +47,79 @@ export async function buildSiteConfigurationForOlmClient( ) .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + if (sitesData.length === 0) { + return siteConfigurations; + } + + // Batch-fetch every site resource this client has access to across ALL sites + // in a single query, then group by siteId in memory. This avoids issuing one + // query per site (which would be N round-trips for N sites). 
+    const allClientSiteResources = await db
+        .select({
+            siteResource: siteResources,
+            siteId: siteNetworks.siteId
+        })
+        .from(siteResources)
+        .innerJoin(
+            clientSiteResourcesAssociationsCache,
+            eq(
+                siteResources.siteResourceId,
+                clientSiteResourcesAssociationsCache.siteResourceId
+            )
+        )
+        .innerJoin(networks, eq(siteResources.networkId, networks.networkId))
+        .innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
+        .where(
+            eq(clientSiteResourcesAssociationsCache.clientId, client.clientId)
+        );
+
+    const siteResourcesBySiteId = new Map<number, SiteResource[]>();
+    for (const row of allClientSiteResources) {
+        const arr = siteResourcesBySiteId.get(row.siteId);
+        if (arr) {
+            arr.push(row.siteResource);
+        } else {
+            siteResourcesBySiteId.set(row.siteId, [row.siteResource]);
+        }
+    }
+
+    // Batch-fetch exit nodes for all sites in one query (only needed in relay mode).
+    const exitNodesById = new Map();
+    if (!jitMode && relay) {
+        const exitNodeIds = Array.from(
+            new Set(
+                sitesData
+                    .map(({ sites: s }) => s.exitNodeId)
+                    .filter((id): id is number => id != null)
+            )
+        );
+        if (exitNodeIds.length > 0) {
+            const nodes = await db
+                .select()
+                .from(exitNodes)
+                .where(inArray(exitNodes.exitNodeId, exitNodeIds));
+            for (const n of nodes) {
+                exitNodesById.set(n.exitNodeId, n);
+            }
+        }
+    }
+
+    const clientsStartPort = config.getRawConfig().gerbil.clients_start_port;
+    const peerOps: Promise<unknown>[] = [];
+
     // Process each site
     for (const {
         sites: site,
         clientSitesAssociationsCache: association
     } of sitesData) {
-        const allSiteResources = await db // only get the site resources that this client has access to
-            .select()
-            .from(siteResources)
-            .innerJoin(
-                clientSiteResourcesAssociationsCache,
-                eq(
-                    siteResources.siteResourceId,
-                    clientSiteResourcesAssociationsCache.siteResourceId
-                )
-            )
-            .innerJoin(
-                networks,
-                eq(siteResources.networkId, networks.networkId)
-            )
-            .innerJoin(
-                siteNetworks,
-                eq(networks.networkId, siteNetworks.networkId)
-            )
-            .where(
-                and(
-                    
eq(siteNetworks.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) - ) - ); + const allSiteResources = siteResourcesBySiteId.get(site.siteId) ?? []; if (jitMode) { // Add site configuration to the array siteConfigurations.push({ siteId: site.siteId, - // remoteSubnets: generateRemoteSubnets( - // allSiteResources.map(({ siteResources }) => siteResources) - // ), - aliases: generateAliasConfig( - allSiteResources.map(({ siteResources }) => siteResources) - ) + // remoteSubnets: generateRemoteSubnets(allSiteResources), + aliases: generateAliasConfig(allSiteResources) }); continue; } @@ -126,7 +157,7 @@ export async function buildSiteConfigurationForOlmClient( logger.info( `Public key mismatch. Deleting old peer from site ${site.siteId}...` ); - await deletePeer(site.siteId, client.pubKey!); + peerOps.push(deletePeer(site.siteId, client.pubKey!)); } if (!site.subnet) { @@ -134,27 +165,19 @@ export async function buildSiteConfigurationForOlmClient( continue; } - const [clientSite] = await db - .select() - .from(clientSitesAssociationsCache) - .where( - and( - eq(clientSitesAssociationsCache.clientId, client.clientId), - eq(clientSitesAssociationsCache.siteId, site.siteId) - ) - ) - .limit(1); - - // Add the peer to the exit node for this site - if (clientSite.endpoint && publicKey) { + // Add the peer to the exit node for this site. The endpoint comes from + // the already-joined association row above, so no extra query needed. + if (association.endpoint && publicKey) { logger.info( - `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}` + `Adding peer ${publicKey} to site ${site.siteId} with endpoint ${association.endpoint}` + ); + peerOps.push( + addPeer(site.siteId, { + publicKey: publicKey, + allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client + endpoint: relay ? 
"" : association.endpoint + }) ); - await addPeer(site.siteId, { - publicKey: publicKey, - allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client - endpoint: relay ? "" : clientSite.endpoint - }); } else { logger.warn( `Client ${client.clientId} has no endpoint, skipping peer addition` @@ -163,16 +186,12 @@ export async function buildSiteConfigurationForOlmClient( let relayEndpoint: string | undefined = undefined; if (relay) { - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)) - .limit(1); + const exitNode = exitNodesById.get(site.exitNodeId); if (!exitNode) { logger.warn(`Exit node not found for site ${site.siteId}`); continue; } - relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; + relayEndpoint = `${exitNode.endpoint}:${clientsStartPort}`; } // Add site configuration to the array @@ -184,12 +203,16 @@ export async function buildSiteConfigurationForOlmClient( publicKey: site.publicKey, serverIP: site.address, serverPort: site.listenPort, - remoteSubnets: generateRemoteSubnets( - allSiteResources.map(({ siteResources }) => siteResources) - ), - aliases: generateAliasConfig( - allSiteResources.map(({ siteResources }) => siteResources) - ) + remoteSubnets: generateRemoteSubnets(allSiteResources), + aliases: generateAliasConfig(allSiteResources) + }); + } + + // Run all peer add/delete operations concurrently rather than serially per + // site, so total time is bounded by the slowest call instead of the sum. 
+    if (peerOps.length > 0) {
+        // NOTE(review): Promise.allSettled never rejects, so the previous
+        // .catch was dead code and the batch was fire-and-forget. Await it so
+        // peer setup completes before returning, and log each failed op.
+        // Confirm deletePeer/addPeer for the same site are order-independent.
+        const settled = await Promise.allSettled(peerOps);
+        for (const r of settled) {
+            if (r.status === "rejected") {
+                logger.error("Error processing peer operations: ", r.reason);
+            }
+        }
+    }

diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts
index 073e86f7b..7735e4d30 100644
--- a/server/routers/olm/handleOlmRegisterMessage.ts
+++ b/server/routers/olm/handleOlmRegisterMessage.ts
@@ -1,4 +1,4 @@
-import { db, orgs } from "@server/db";
+import { db, orgs, primaryDb } from "@server/db";
 import { MessageHandler } from "@server/routers/ws";
 import {
     clients,
@@ -81,7 +81,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
             .where(eq(olms.olmId, olm.olmId));
     }

-    const [client] = await db
+    const [client] = await primaryDb // read from the primary here so there is no latency with the last update on the holepunch
         .select()
         .from(clients)
         .where(eq(clients.clientId, olm.clientId))
@@ -98,7 +98,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
     if (client.blocked) {
         logger.debug(
             `[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`,
-            { orgId: client.orgId }
+            { orgId: client.orgId, clientId: client.clientId }
         );
         sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
         return;
@@ -107,7 +107,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
     if (client.approvalState == "pending") {
         logger.debug(
             `[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. 
Ignoring register.`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId); return; @@ -136,7 +136,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (!org) { logger.warn("[handleOlmRegisterMessage] Org not found", { - orgId: client.orgId + orgId: client.orgId, + clientId: client.clientId }); sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId); return; @@ -145,7 +146,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (orgId) { if (!olm.userId) { logger.warn("[handleOlmRegisterMessage] Olm has no user ID", { - orgId: client.orgId + orgId: client.orgId, + clientId: client.clientId }); sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId); return; @@ -156,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (!userSession || !user) { logger.warn( "[handleOlmRegisterMessage] Invalid user session for olm register", - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId); return; @@ -164,7 +166,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (user.userId !== olm.userId) { logger.warn( "[handleOlmRegisterMessage] User ID mismatch for olm register", - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId); return; @@ -182,13 +184,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { logger.debug("[handleOlmRegisterMessage] Policy check result", { orgId: client.orgId, + clientId: client.clientId, policyCheck }); if (policyCheck?.error) { logger.error( `[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); 
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); return; @@ -197,7 +200,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (policyCheck.policies?.passwordAge?.compliant === false) { logger.warn( `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError( OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED, @@ -209,7 +212,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ) { logger.warn( `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError( OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED, @@ -219,7 +222,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { } else if (policyCheck.policies?.requiredTwoFactor === false) { logger.warn( `[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError( OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED, @@ -229,7 +232,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { } else if (!policyCheck.allowed) { logger.warn( `[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); return; @@ -253,7 +256,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { // Prepare an array to store site configurations logger.debug( `[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: 
client.clientId } ); let jitMode = false; @@ -263,19 +266,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites logger.info( `[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); jitMode = true; } logger.debug( `[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); if (!publicKey) { logger.warn("[handleOlmRegisterMessage] Public key not provided", { - orgId: client.orgId + orgId: client.orgId, + clientId: client.clientId }); return; } @@ -283,7 +287,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (client.pubKey !== publicKey || client.archived) { logger.info( "[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...", - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); // Update the client's public key await db @@ -321,7 +325,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) { logger.warn( `[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`, - { orgId: client.orgId } + { orgId: client.orgId, clientId: client.clientId } ); return; }