diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 504ea761..63bb0535 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -720,6 +720,7 @@ export const clientSitesAssociationsCache = pgTable( .notNull(), siteId: integer("siteId").notNull(), isRelayed: boolean("isRelayed").notNull().default(false), + isJitMode: boolean("isJitMode").notNull().default(false), endpoint: varchar("endpoint"), publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 2bd11ee0..6183dfd7 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -409,6 +409,9 @@ export const clientSitesAssociationsCache = sqliteTable( isRelayed: integer("isRelayed", { mode: "boolean" }) .notNull() .default(false), + isJitMode: integer("isJitMode", { mode: "boolean" }) + .notNull() + .default(false), endpoint: text("endpoint"), publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes } diff --git a/server/private/routers/ssh/signSshKey.ts b/server/private/routers/ssh/signSshKey.ts index d6fe88eb..e7095181 100644 --- a/server/private/routers/ssh/signSshKey.ts +++ b/server/private/routers/ssh/signSshKey.ts @@ -63,6 +63,7 @@ export type SignSshKeyResponse = { sshUsername: string; sshHost: string; resourceId: number; + siteId: number; keyId: string; validPrincipals: string[]; validAfter: string; @@ -452,6 +453,7 @@ export async function signSshKey( sshUsername: usernameToUse, sshHost: sshHost, resourceId: resource.siteResourceId, + siteId: resource.siteId, keyId: cert.keyId, validPrincipals: cert.validPrincipals, validAfter: cert.validAfter.toISOString(), diff --git a/server/routers/newt/buildConfiguration.ts b/server/routers/newt/buildConfiguration.ts index 65cb18a2..3c55d6b9 100644 --- a/server/routers/newt/buildConfiguration.ts +++ b/server/routers/newt/buildConfiguration.ts @@ -1,4 +1,15 @@ -import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db"; +import { + clients, + clientSiteResourcesAssociationsCache, + clientSitesAssociationsCache, + db, + ExitNode, + resources, + Site, + siteResources, + targetHealthCheck, + targets +} from "@server/db"; import logger from "@server/logger"; import { initPeerAddHandshake, updatePeer } from "../olm/peers"; import { eq, and } from "drizzle-orm"; @@ -69,40 +80,42 @@ export async function buildClientConfigurationForNewtClient( // ) // ); - // update the peer info on the olm - // if the peer has not been added yet this will be a no-op - await updatePeer(client.clients.clientId, { - siteId: site.siteId, - endpoint: site.endpoint!, - relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`, - publicKey: site.publicKey!, - serverIP: site.address, - serverPort: site.listenPort - // remoteSubnets: generateRemoteSubnets( - // allSiteResources.map( - // ({ siteResources }) => siteResources - // ) - // ), - // aliases: generateAliasConfig( - // allSiteResources.map( - // ({ siteResources }) => siteResources - // ) - // ) - }); + if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm + // update the peer info on the olm + // if the peer has not been added yet this will be a no-op + await updatePeer(client.clients.clientId, { + siteId: site.siteId, + endpoint: site.endpoint!, + relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`, + publicKey: site.publicKey!, + serverIP: site.address, + serverPort: site.listenPort + // remoteSubnets: generateRemoteSubnets( + // allSiteResources.map( + // ({ siteResources }) => siteResources + // ) + // ), + // aliases: generateAliasConfig( + // allSiteResources.map( + // ({ siteResources }) => siteResources + // ) + // ) + }); - // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch - // if it has already been added this will be a no-op - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clients.clientId, - { - siteId, - exitNode: { - publicKey: exitNode.publicKey, - endpoint: exitNode.endpoint + // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch + // if it has already been added this will be a no-op + await initPeerAddHandshake( + // this will kick off the add peer process for the client + client.clients.clientId, + { + siteId, + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint + } } - } - ); + ); + } return { publicKey: client.clients.pubKey!, diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 1e209d13..91a2aa13 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -209,6 +209,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { } } + // Get all sites data + const sitesCountResult = await db + .select({ count: count() }) + .from(sites) + .innerJoin( + clientSitesAssociationsCache, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + + // Extract the count value from the result array + const sitesCount = + sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; + + // Prepare an array to store site configurations + logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`); + + let jitMode = false; + if (sitesCount > 250 && build == "saas") { + // THIS IS THE MAX ON THE BUSINESS TIER + // we have too many sites + // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites + logger.info("Too many sites (%d), dropping into JIT mode", sitesCount) + jitMode = true; + } + logger.debug( `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` ); @@ -235,28 +261,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { await db .update(clientSitesAssociationsCache) .set({ - isRelayed: relay == true + isRelayed: relay == true, + isJitMode: jitMode }) .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); } - // Get all sites data - const sitesCountResult = await db - .select({ count: count() }) - .from(sites) - .innerJoin( - clientSitesAssociationsCache, - eq(sites.siteId, clientSitesAssociationsCache.siteId) - ) - .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - - // Extract the count value from the result array - const sitesCount = - sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; - - // Prepare an array to store site configurations - logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`); - // this prevents us from accepting a register from an olm that has not hole punched yet. // the olm will pump the register so we can keep checking // TODO: I still think there is a better way to do this rather than locking it out here but ??? @@ -278,8 +288,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { aliases: Alias[]; }[] = []; - // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites - if (sitesCount <= 250 && build == "saas") { // THIS IS THE MAX ON THE BUSINESS TIER + if (!jitMode) { // NOTE: its important that the client here is the old client and the public key is the new key siteConfigurations = await buildSiteConfigurationForOlmClient( client, diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 66453008..06621cac 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -1,8 +1,8 @@ import { sendToClient } from "#dynamic/routers/ws"; -import { db, olms } from "@server/db"; +import { clientSitesAssociationsCache, db, olms } from "@server/db"; import config from "@server/lib/config"; import logger from "@server/logger"; -import { eq } from "drizzle-orm"; +import { and, eq } from "drizzle-orm"; import { Alias } from "yaml"; export async function addPeer( @@ -150,7 +150,7 @@ export async function initPeerAddHandshake( }; }, olmId?: string, - chainId?: string, + chainId?: string ) { if (!olmId) { const [olm] = await db @@ -175,7 +175,7 @@ export async function initPeerAddHandshake( relayPort: config.getRawConfig().gerbil.clients_start_port, endpoint: peer.exitNode.endpoint }, - chainId, + chainId } }, { incrementConfigVersion: true } @@ -183,6 +183,17 @@ export async function initPeerAddHandshake( logger.warn(`Error sending message:`, error); }); + // update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection + await db + .update(clientSitesAssociationsCache) + .set({ isJitMode: false }) + .where( + and( + eq(clientSitesAssociationsCache.clientId, clientId), + eq(clientSitesAssociationsCache.siteId, peer.siteId) + ) + ); + logger.info( `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` );