diff --git a/server/routers/gerbil/updateHolePunch.ts b/server/routers/gerbil/updateHolePunch.ts index ac3ff3945..ed60c6a40 100644 --- a/server/routers/gerbil/updateHolePunch.ts +++ b/server/routers/gerbil/updateHolePunch.ts @@ -411,12 +411,14 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) { return; } - // Get all non-relayed clients connected to this site + // Get all non-relayed and not jit clients connected to this site const connectedClients = await db .select({ + online: clients.online, clientId: clients.clientId, olmId: olms.olmId, - isRelayed: clientSitesAssociationsCache.isRelayed + isRelayed: clientSitesAssociationsCache.isRelayed, + isJitMode: clientSitesAssociationsCache.isJitMode }) .from(clientSitesAssociationsCache) .innerJoin( @@ -426,32 +428,36 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) { .innerJoin(olms, eq(olms.clientId, clients.clientId)) .where( and( + eq(clients.online, true), // the client has to be online or it does not matter... eq(clientSitesAssociationsCache.siteId, siteId), - eq(clientSitesAssociationsCache.isRelayed, false) + eq(clientSitesAssociationsCache.isRelayed, false), + eq(clientSitesAssociationsCache.isJitMode, false) ) ); - // Update each non-relayed client with the new site endpoint - for (const client of connectedClients) { - try { - await updateOlmPeer( - client.clientId, - { - siteId: siteId, - publicKey: site.publicKey, - endpoint: newEndpoint - }, - client.olmId - ); - logger.debug( - `Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}` - ); - } catch (error) { - logger.error( - `Failed to update client ${client.clientId} with new site endpoint: ${error}` - ); - } - } + // Update each non-relayed client with the new site endpoint (in parallel) + await Promise.allSettled( + connectedClients.map(async (client) => { + try { + await updateOlmPeer( + client.clientId, + { + siteId: siteId, + publicKey: site.publicKey!, + endpoint: newEndpoint + }, + client.olmId + ); + logger.debug( + `Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}` + ); + } catch (error) { + logger.error( + `Failed to update client ${client.clientId} with new site endpoint: ${error}` + ); + } + }) + ); } catch (error) { logger.error( `Error handling site endpoint change for site ${siteId}: ${error}` @@ -498,6 +504,7 @@ async function handleClientEndpointChange( // TODO: I THINK WE DONT NEED TO HIT ) .where( and( + eq(sites.online, true), // the site has to be online or it does not matter... eq(clientSitesAssociationsCache.clientId, clientId), eq(clientSitesAssociationsCache.isRelayed, false), eq(clientSitesAssociationsCache.isJitMode, false) diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index 5b8411eb7..51cdc8b4e 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -8,7 +8,7 @@ import { ExitNode, exitNodes, sites, - clientSitesAssociationsCache, + clientSitesAssociationsCache } from "@server/db"; import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; @@ -28,6 +28,7 @@ import { verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; import config from "@server/lib/config"; import { APP_VERSION } from "@server/lib/consts"; +import { build } from "@server/build"; export const olmGetTokenBodySchema = z.object({ olmId: z.string(), @@ -220,6 +221,22 @@ export async function getOlmToken( ) .where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!)); + if (clientSites.length > 250 && build == "saas") { + // set all of the cache rows isJitMode to true + await db + .update(clientSitesAssociationsCache) + .set({ isJitMode: true }) + .where( + and( + eq( + clientSitesAssociationsCache.clientId, + clientIdToUse! + ), + eq(clientSitesAssociationsCache.isJitMode, false) + ) + ); + } + // Extract unique exit node IDs const exitNodeIds = Array.from( new Set( diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index aaaf81c41..073e86f7b 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -7,7 +7,7 @@ import { olms, sites } from "@server/db"; -import { count, eq } from "drizzle-orm"; +import { and, count, eq, ne, or } from "drizzle-orm"; import logger from "@server/logger"; import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -301,7 +301,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { isRelayed: relay == true, isJitMode: jitMode }) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + or( + ne( + clientSitesAssociationsCache.isRelayed, + relay == true + ), + ne(clientSitesAssociationsCache.isJitMode, jitMode) + ) + ) + ); } // this prevents us from accepting a register from an olm that has not hole punched yet.