Handle JIT for ssh

This commit is contained in:
Owen
2026-03-06 15:47:41 -08:00
parent 9405b0b70a
commit 0503c6e66e
6 changed files with 96 additions and 57 deletions

View File

@@ -720,6 +720,7 @@ export const clientSitesAssociationsCache = pgTable(
.notNull(), .notNull(),
siteId: integer("siteId").notNull(), siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false), isRelayed: boolean("isRelayed").notNull().default(false),
isJitMode: boolean("isJitMode").notNull().default(false),
endpoint: varchar("endpoint"), endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }

View File

@@ -409,6 +409,9 @@ export const clientSitesAssociationsCache = sqliteTable(
isRelayed: integer("isRelayed", { mode: "boolean" }) isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull() .notNull()
.default(false), .default(false),
isJitMode: integer("isJitMode", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint"), endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }

View File

@@ -63,6 +63,7 @@ export type SignSshKeyResponse = {
sshUsername: string; sshUsername: string;
sshHost: string; sshHost: string;
resourceId: number; resourceId: number;
siteId: number;
keyId: string; keyId: string;
validPrincipals: string[]; validPrincipals: string[];
validAfter: string; validAfter: string;
@@ -452,6 +453,7 @@ export async function signSshKey(
sshUsername: usernameToUse, sshUsername: usernameToUse,
sshHost: sshHost, sshHost: sshHost,
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
siteId: resource.siteId,
keyId: cert.keyId, keyId: cert.keyId,
validPrincipals: cert.validPrincipals, validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(), validAfter: cert.validAfter.toISOString(),

View File

@@ -1,4 +1,15 @@
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db"; import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
ExitNode,
resources,
Site,
siteResources,
targetHealthCheck,
targets
} from "@server/db";
import logger from "@server/logger"; import logger from "@server/logger";
import { initPeerAddHandshake, updatePeer } from "../olm/peers"; import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
@@ -69,40 +80,42 @@ export async function buildClientConfigurationForNewtClient(
// ) // )
// ); // );
// update the peer info on the olm if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
// if the peer has not been added yet this will be a no-op // update the peer info on the olm
await updatePeer(client.clients.clientId, { // if the peer has not been added yet this will be a no-op
siteId: site.siteId, await updatePeer(client.clients.clientId, {
endpoint: site.endpoint!, siteId: site.siteId,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`, endpoint: site.endpoint!,
publicKey: site.publicKey!, relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
serverIP: site.address, publicKey: site.publicKey!,
serverPort: site.listenPort serverIP: site.address,
// remoteSubnets: generateRemoteSubnets( serverPort: site.listenPort
// allSiteResources.map( // remoteSubnets: generateRemoteSubnets(
// ({ siteResources }) => siteResources // allSiteResources.map(
// ) // ({ siteResources }) => siteResources
// ), // )
// aliases: generateAliasConfig( // ),
// allSiteResources.map( // aliases: generateAliasConfig(
// ({ siteResources }) => siteResources // allSiteResources.map(
// ) // ({ siteResources }) => siteResources
// ) // )
}); // )
});
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op // if it has already been added this will be a no-op
await initPeerAddHandshake( await initPeerAddHandshake(
// this will kick off the add peer process for the client // this will kick off the add peer process for the client
client.clients.clientId, client.clients.clientId,
{ {
siteId, siteId,
exitNode: { exitNode: {
publicKey: exitNode.publicKey, publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint endpoint: exitNode.endpoint
}
} }
} );
); }
return { return {
publicKey: client.clients.pubKey!, publicKey: client.clients.pubKey!,

View File

@@ -209,6 +209,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
} }
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
let jitMode = false;
if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount)
jitMode = true;
}
logger.debug( logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
); );
@@ -235,28 +261,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db await db
.update(clientSitesAssociationsCache) .update(clientSitesAssociationsCache)
.set({ .set({
isRelayed: relay == true isRelayed: relay == true,
isJitMode: jitMode
}) })
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
} }
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking // the olm will pump the register so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
@@ -278,8 +288,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
aliases: Alias[]; aliases: Alias[];
}[] = []; }[] = [];
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites if (!jitMode) {
if (sitesCount <= 250 && build == "saas") { // THIS IS THE MAX ON THE BUSINESS TIER
// NOTE: its important that the client here is the old client and the public key is the new key // NOTE: its important that the client here is the old client and the public key is the new key
siteConfigurations = await buildSiteConfigurationForOlmClient( siteConfigurations = await buildSiteConfigurationForOlmClient(
client, client,

View File

@@ -1,8 +1,8 @@
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db"; import { clientSitesAssociationsCache, db, olms } from "@server/db";
import config from "@server/lib/config"; import config from "@server/lib/config";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { Alias } from "yaml"; import { Alias } from "yaml";
export async function addPeer( export async function addPeer(
@@ -150,7 +150,7 @@ export async function initPeerAddHandshake(
}; };
}, },
olmId?: string, olmId?: string,
chainId?: string, chainId?: string
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -175,7 +175,7 @@ export async function initPeerAddHandshake(
relayPort: config.getRawConfig().gerbil.clients_start_port, relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint endpoint: peer.exitNode.endpoint
}, },
chainId, chainId
} }
}, },
{ incrementConfigVersion: true } { incrementConfigVersion: true }
@@ -183,6 +183,17 @@ export async function initPeerAddHandshake(
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
// update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.siteId, peer.siteId)
)
);
logger.info( logger.info(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
); );