Compare commits

...

6 Commits

Author SHA1 Message Date
Owen
0503c6e66e Handle JIT for ssh 2026-03-06 15:49:17 -08:00
Owen
9405b0b70a Force jit above site limit 2026-03-06 14:09:57 -08:00
Owen
2a5c9465e9 Add chainId field passthrough 2026-03-04 22:17:58 -08:00
Owen
f36b66e397 Merge branch 'dev' into jit 2026-03-04 17:58:50 -08:00
Owen
1bfff630bf Jit working for sites 2026-03-04 17:46:58 -08:00
Owen
c73a39f797 Allow JIT based on site or resource 2026-03-04 15:44:27 -08:00
13 changed files with 374 additions and 77 deletions

View File

@@ -720,6 +720,7 @@ export const clientSitesAssociationsCache = pgTable(
.notNull(), .notNull(),
siteId: integer("siteId").notNull(), siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false), isRelayed: boolean("isRelayed").notNull().default(false),
isJitMode: boolean("isJitMode").notNull().default(false),
endpoint: varchar("endpoint"), endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }

View File

@@ -409,6 +409,9 @@ export const clientSitesAssociationsCache = sqliteTable(
isRelayed: integer("isRelayed", { mode: "boolean" }) isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull() .notNull()
.default(false), .default(false),
isJitMode: integer("isJitMode", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint"), endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }

View File

@@ -29,7 +29,6 @@ import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { eq, or, and } from "drizzle-orm"; import { eq, or, and } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource"; import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA"; import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
@@ -64,6 +63,7 @@ export type SignSshKeyResponse = {
sshUsername: string; sshUsername: string;
sshHost: string; sshHost: string;
resourceId: number; resourceId: number;
siteId: number;
keyId: string; keyId: string;
validPrincipals: string[]; validPrincipals: string[];
validAfter: string; validAfter: string;
@@ -453,6 +453,7 @@ export async function signSshKey(
sshUsername: usernameToUse, sshUsername: usernameToUse,
sshHost: sshHost, sshHost: sshHost,
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
siteId: resource.siteId,
keyId: cert.keyId, keyId: cert.keyId,
validPrincipals: cert.validPrincipals, validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(), validAfter: cert.validAfter.toISOString(),

View File

@@ -76,7 +76,7 @@ const processMessage = async (
clientId, clientId,
message.type, // Pass message type for granular limiting message.type, // Pass message type for granular limiting
100, // max requests per window 100, // max requests per window
20, // max requests per message type per window 100, // max requests per message type per window
60 * 1000 // window in milliseconds 60 * 1000 // window in milliseconds
); );
if (rateLimitResult.isLimited) { if (rateLimitResult.isLimited) {

View File

@@ -1,4 +1,15 @@
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db"; import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
ExitNode,
resources,
Site,
siteResources,
targetHealthCheck,
targets
} from "@server/db";
import logger from "@server/logger"; import logger from "@server/logger";
import { initPeerAddHandshake, updatePeer } from "../olm/peers"; import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
@@ -69,40 +80,42 @@ export async function buildClientConfigurationForNewtClient(
// ) // )
// ); // );
// update the peer info on the olm if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
// if the peer has not been added yet this will be a no-op // update the peer info on the olm
await updatePeer(client.clients.clientId, { // if the peer has not been added yet this will be a no-op
siteId: site.siteId, await updatePeer(client.clients.clientId, {
endpoint: site.endpoint!, siteId: site.siteId,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`, endpoint: site.endpoint!,
publicKey: site.publicKey!, relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
serverIP: site.address, publicKey: site.publicKey!,
serverPort: site.listenPort serverIP: site.address,
// remoteSubnets: generateRemoteSubnets( serverPort: site.listenPort
// allSiteResources.map( // remoteSubnets: generateRemoteSubnets(
// ({ siteResources }) => siteResources // allSiteResources.map(
// ) // ({ siteResources }) => siteResources
// ), // )
// aliases: generateAliasConfig( // ),
// allSiteResources.map( // aliases: generateAliasConfig(
// ({ siteResources }) => siteResources // allSiteResources.map(
// ) // ({ siteResources }) => siteResources
// ) // )
}); // )
});
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op // if it has already been added this will be a no-op
await initPeerAddHandshake( await initPeerAddHandshake(
// this will kick off the add peer process for the client // this will kick off the add peer process for the client
client.clients.clientId, client.clients.clientId,
{ {
siteId, siteId,
exitNode: { exitNode: {
publicKey: exitNode.publicKey, publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint endpoint: exitNode.endpoint
}
} }
} );
); }
return { return {
publicKey: client.clients.pubKey!, publicKey: client.clients.pubKey!,

View File

@@ -17,6 +17,8 @@ import { getUserDeviceName } from "@server/db/names";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration"; import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { OlmErrorCodes, sendOlmError } from "./error"; import { OlmErrorCodes, sendOlmError } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils"; import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("Handling register olm message!");
@@ -207,6 +209,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
} }
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
let jitMode = false;
if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount)
jitMode = true;
}
logger.debug( logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
); );
@@ -233,28 +261,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db await db
.update(clientSitesAssociationsCache) .update(clientSitesAssociationsCache)
.set({ .set({
isRelayed: relay == true isRelayed: relay == true,
isJitMode: jitMode
}) })
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
} }
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking // the olm will pump the register so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
@@ -265,18 +277,25 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} }
// NOTE: its important that the client here is the old client and the public key is the new key let siteConfigurations: {
const siteConfigurations = await buildSiteConfigurationForOlmClient( siteId: number;
client, name: string;
publicKey, endpoint: string;
relay publicKey: string | null;
); serverIP: string | null;
serverPort: number | null;
remoteSubnets: string[];
aliases: Alias[];
}[] = [];
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES if (!jitMode) {
// if (siteConfigurations.length === 0) { // NOTE: its important that the client here is the old client and the public key is the new key
// logger.warn("No valid site configurations found"); siteConfigurations = await buildSiteConfigurationForOlmClient(
// return; client,
// } publicKey,
relay
);
}
// Return connect message with all site configurations // Return connect message with all site configurations
return { return {

View File

@@ -18,7 +18,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
} }
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? logger.warn("Olm has no client!");
return; return;
} }
@@ -41,7 +41,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
return; return;
} }
const { siteId } = message.data; const { siteId, chainId } = message.data;
// Get the site // Get the site
const [site] = await db const [site] = await db
@@ -90,7 +90,8 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
data: { data: {
siteId: siteId, siteId: siteId,
relayEndpoint: exitNode.endpoint, relayEndpoint: exitNode.endpoint,
relayPort: config.getRawConfig().gerbil.clients_start_port relayPort: config.getRawConfig().gerbil.clients_start_port,
chainId
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -0,0 +1,241 @@
import {
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
Site,
siteResources
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers";
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
context
) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
const { siteId, resourceId, chainId } = message.data;
let site: Site | null = null;
if (siteId) {
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (siteRes) {
site = siteRes;
}
}
if (resourceId && !site) {
const resources = await db
.select()
.from(siteResources)
.where(
and(
or(
eq(siteResources.niceId, resourceId),
eq(siteResources.alias, resourceId)
),
eq(siteResources.orgId, client.orgId)
)
);
if (!resources || resources.length === 0) {
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
logger.error(
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
);
return;
}
const resource = resources[0];
const currentResourceAssociationCaches = await db
.select()
.from(clientSiteResourcesAssociationsCache)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
resource.siteResourceId
),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (currentResourceAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
const siteIdFromResource = resource.siteId;
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteIdFromResource));
if (!siteRes) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site} not found`
);
return;
}
site = siteRes;
}
if (!site) {
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
return;
}
// check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
if (currentSiteAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
// get the exit node from the side
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
return;
}
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId,
chainId
);
return;
};

View File

@@ -54,7 +54,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
return; return;
} }
const { siteId } = message.data; const { siteId, chainId } = message.data;
// get the site // get the site
const [site] = await db const [site] = await db
@@ -179,7 +179,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
), ),
aliases: generateAliasConfig( aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources) allSiteResources.map(({ siteResources }) => siteResources)
) ),
chainId: chainId,
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -17,7 +17,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
} }
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? logger.warn("Olm has no client!");
return; return;
} }
@@ -40,7 +40,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
return; return;
} }
const { siteId } = message.data; const { siteId, chainId } = message.data;
// Get the site // Get the site
const [site] = await db const [site] = await db
@@ -87,7 +87,8 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
type: "olm/wg/peer/unrelay", type: "olm/wg/peer/unrelay",
data: { data: {
siteId: siteId, siteId: siteId,
endpoint: site.endpoint endpoint: site.endpoint,
chainId
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -11,3 +11,4 @@ export * from "./handleOlmServerPeerAddMessage";
export * from "./handleOlmUnRelayMessage"; export * from "./handleOlmUnRelayMessage";
export * from "./recoverOlmWithFingerprint"; export * from "./recoverOlmWithFingerprint";
export * from "./handleOlmDisconnectingMessage"; export * from "./handleOlmDisconnectingMessage";
export * from "./handleOlmServerInitAddPeerHandshake";

View File

@@ -1,8 +1,8 @@
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db"; import { clientSitesAssociationsCache, db, olms } from "@server/db";
import config from "@server/lib/config"; import config from "@server/lib/config";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { Alias } from "yaml"; import { Alias } from "yaml";
export async function addPeer( export async function addPeer(
@@ -149,7 +149,8 @@ export async function initPeerAddHandshake(
endpoint: string; endpoint: string;
}; };
}, },
olmId?: string olmId?: string,
chainId?: string
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -173,7 +174,8 @@ export async function initPeerAddHandshake(
publicKey: peer.exitNode.publicKey, publicKey: peer.exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port, relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint endpoint: peer.exitNode.endpoint
} },
chainId
} }
}, },
{ incrementConfigVersion: true } { incrementConfigVersion: true }
@@ -181,6 +183,17 @@ export async function initPeerAddHandshake(
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
// update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.siteId, peer.siteId)
)
);
logger.info( logger.info(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
); );

View File

@@ -15,7 +15,8 @@ import {
startOlmOfflineChecker, startOlmOfflineChecker,
handleOlmServerPeerAddMessage, handleOlmServerPeerAddMessage,
handleOlmUnRelayMessage, handleOlmUnRelayMessage,
handleOlmDisconnecingMessage handleOlmDisconnecingMessage,
handleOlmServerInitAddPeerHandshake
} from "../olm"; } from "../olm";
import { handleHealthcheckStatusMessage } from "../target"; import { handleHealthcheckStatusMessage } from "../target";
import { handleRoundTripMessage } from "./handleRoundTripMessage"; import { handleRoundTripMessage } from "./handleRoundTripMessage";
@@ -23,6 +24,7 @@ import { MessageHandler } from "./types";
export const messageHandlers: Record<string, MessageHandler> = { export const messageHandlers: Record<string, MessageHandler> = {
"olm/wg/server/peer/add": handleOlmServerPeerAddMessage, "olm/wg/server/peer/add": handleOlmServerPeerAddMessage,
"olm/wg/server/peer/init": handleOlmServerInitAddPeerHandshake,
"olm/wg/register": handleOlmRegisterMessage, "olm/wg/register": handleOlmRegisterMessage,
"olm/wg/relay": handleOlmRelayMessage, "olm/wg/relay": handleOlmRelayMessage,
"olm/wg/unrelay": handleOlmUnRelayMessage, "olm/wg/unrelay": handleOlmUnRelayMessage,