Improve holepunching

This commit is contained in:
Owen
2025-12-01 11:51:20 -05:00
parent 8c62dfa706
commit a623604e96
16 changed files with 427 additions and 170 deletions

View File

@@ -36,13 +36,15 @@ export async function createSession(
const sessionId = encodeHexLowerCase( const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)) sha256(new TextEncoder().encode(token))
); );
const session: Session = { const [session] = await db
sessionId: sessionId, .insert(sessions)
userId, .values({
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(), sessionId: sessionId,
issuedAt: new Date().getTime() userId,
}; expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
await db.insert(sessions).values(session); issuedAt: new Date().getTime()
})
.returning();
return session; return session;
} }

View File

@@ -288,7 +288,7 @@ export const sessions = pgTable("session", {
.references(() => users.userId, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
expiresAt: bigint("expiresAt", { mode: "number" }).notNull(), expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
issuedAt: bigint("issuedAt", { mode: "number" }), issuedAt: bigint("issuedAt", { mode: "number" }),
deviceAuthUsed: boolean("deviceAuthUsed") deviceAuthUsed: boolean("deviceAuthUsed").notNull().default(false)
}); });
export const newtSessions = pgTable("newtSession", { export const newtSessions = pgTable("newtSession", {
@@ -665,7 +665,8 @@ export const clientSitesAssociationsCache = pgTable(
.notNull(), .notNull(),
siteId: integer("siteId").notNull(), siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false), isRelayed: boolean("isRelayed").notNull().default(false),
endpoint: varchar("endpoint") endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }
); );

View File

@@ -1,6 +1,7 @@
import { randomUUID } from "crypto"; import { randomUUID } from "crypto";
import { InferSelectModel } from "drizzle-orm"; import { InferSelectModel } from "drizzle-orm";
import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core"; import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core";
import { no } from "zod/v4/locales";
export const domains = sqliteTable("domains", { export const domains = sqliteTable("domains", {
domainId: text("domainId").primaryKey(), domainId: text("domainId").primaryKey(),
@@ -372,7 +373,8 @@ export const clientSitesAssociationsCache = sqliteTable(
isRelayed: integer("isRelayed", { mode: "boolean" }) isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull() .notNull()
.default(false), .default(false),
endpoint: text("endpoint") endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }
); );
@@ -417,6 +419,8 @@ export const sessions = sqliteTable("session", {
expiresAt: integer("expiresAt").notNull(), expiresAt: integer("expiresAt").notNull(),
issuedAt: integer("issuedAt"), issuedAt: integer("issuedAt"),
deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" }) deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" })
.notNull()
.default(false)
}); });
export const newtSessions = sqliteTable("newtSession", { export const newtSessions = sqliteTable("newtSession", {

View File

@@ -229,6 +229,11 @@ export const configSchema = z
.default(51820) .default(51820)
.transform(stoi) .transform(stoi)
.pipe(portSchema), .pipe(portSchema),
clients_start_port: portSchema
.optional()
.default(21820)
.transform(stoi)
.pipe(portSchema),
base_endpoint: z base_endpoint: z
.string() .string()
.optional() .optional()

View File

@@ -25,6 +25,7 @@ import {
deletePeer as newtDeletePeer deletePeer as newtDeletePeer
} from "@server/routers/newt/peers"; } from "@server/routers/newt/peers";
import { import {
initPeerAddHandshake as holepunchSiteAdd,
addPeer as olmAddPeer, addPeer as olmAddPeer,
deletePeer as olmDeletePeer deletePeer as olmDeletePeer
} from "@server/routers/olm/peers"; } from "@server/routers/olm/peers";
@@ -464,65 +465,16 @@ async function handleMessagesForSiteClients(
} }
if (isAdd) { if (isAdd) {
// TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES await holepunchSiteAdd( // this will kick off the add peer process for the client
// BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS client.clientId,
// AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES {
const isRelayed = true;
newtJobs.push(
newtAddPeer(
siteId, siteId,
{ exitNode: {
publicKey: client.pubKey, publicKey: exitNode.publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client endpoint: exitNode.endpoint
// endpoint: isRelayed ? "" : clientSite.endpoint }
endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint },
}, olm.olmId
newt.newtId
)
);
// TODO: should we have this here?
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
olmJobs.push(
olmAddPeer(
client.clientId,
{
siteId: site.siteId,
endpoint:
isRelayed || !site.endpoint
? `${exitNode.endpoint}:21820`
: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(
({ siteResources }) => siteResources
)
)
},
olm.olmId
)
); );
} }

View File

@@ -1369,7 +1369,7 @@ const updateHolePunchSchema = z.object({
port: z.number(), port: z.number(),
timestamp: z.number(), timestamp: z.number(),
reachableAt: z.string().optional(), reachableAt: z.string().optional(),
publicKey: z.string().optional() publicKey: z.string() // this is the client public key
}); });
hybridRouter.post( hybridRouter.post(
"/gerbil/update-hole-punch", "/gerbil/update-hole-punch",
@@ -1408,7 +1408,7 @@ hybridRouter.post(
); );
} }
const { olmId, newtId, ip, port, timestamp, token, reachableAt } = const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
parsedParams.data; parsedParams.data;
const destinations = await updateAndGenerateEndpointDestinations( const destinations = await updateAndGenerateEndpointDestinations(
@@ -1418,6 +1418,7 @@ hybridRouter.post(
port, port,
timestamp, timestamp,
token, token,
publicKey,
exitNode, exitNode,
true true
); );

View File

@@ -30,8 +30,9 @@ const updateHolePunchSchema = z.object({
ip: z.string(), ip: z.string(),
port: z.number(), port: z.number(),
timestamp: z.number(), timestamp: z.number(),
publicKey: z.string(),
reachableAt: z.string().optional(), reachableAt: z.string().optional(),
publicKey: z.string().optional() exitNodePublicKey: z.string().optional()
}); });
// New response type with multi-peer destination support // New response type with multi-peer destination support
@@ -65,23 +66,26 @@ export async function updateHolePunch(
timestamp, timestamp,
token, token,
reachableAt, reachableAt,
publicKey publicKey, // this is the client's current public key for this session
exitNodePublicKey
} = parsedParams.data; } = parsedParams.data;
let exitNode: ExitNode | undefined; let exitNode: ExitNode | undefined;
if (publicKey) { if (exitNodePublicKey) {
// Get the exit node by public key // Get the exit node by public key
[exitNode] = await db [exitNode] = await db
.select() .select()
.from(exitNodes) .from(exitNodes)
.where(eq(exitNodes.publicKey, publicKey)); .where(eq(exitNodes.publicKey, exitNodePublicKey));
} else { } else {
// FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL =<1.1.0 // FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL =<1.1.0
[exitNode] = await db.select().from(exitNodes).limit(1); [exitNode] = await db.select().from(exitNodes).limit(1);
} }
if (!exitNode) { if (!exitNode) {
logger.warn(`Exit node not found for publicKey: ${publicKey}`); logger.warn(
`Exit node not found for publicKey: ${exitNodePublicKey}`
);
return next( return next(
createHttpError(HttpCode.NOT_FOUND, "Exit node not found") createHttpError(HttpCode.NOT_FOUND, "Exit node not found")
); );
@@ -94,12 +98,13 @@ export async function updateHolePunch(
port, port,
timestamp, timestamp,
token, token,
publicKey,
exitNode exitNode
); );
logger.debug( // logger.debug(
`Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}` // `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}`
); // );
// Return the new multi-peer structure // Return the new multi-peer structure
return res.status(HttpCode.OK).send({ return res.status(HttpCode.OK).send({
@@ -123,6 +128,7 @@ export async function updateAndGenerateEndpointDestinations(
port: number, port: number,
timestamp: number, timestamp: number,
token: string, token: string,
publicKey: string,
exitNode: ExitNode, exitNode: ExitNode,
checkOrg = false checkOrg = false
) { ) {
@@ -130,9 +136,9 @@ export async function updateAndGenerateEndpointDestinations(
const destinations: PeerDestination[] = []; const destinations: PeerDestination[] = [];
if (olmId) { if (olmId) {
logger.debug( // logger.debug(
`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}` // `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`
); // );
const { session, olm: olmSession } = const { session, olm: olmSession } =
await validateOlmSessionToken(token); await validateOlmSessionToken(token);
@@ -180,6 +186,7 @@ export async function updateAndGenerateEndpointDestinations(
siteId: sites.siteId, siteId: sites.siteId,
subnet: sites.subnet, subnet: sites.subnet,
listenPort: sites.listenPort, listenPort: sites.listenPort,
publicKey: sites.publicKey,
endpoint: clientSitesAssociationsCache.endpoint endpoint: clientSitesAssociationsCache.endpoint
}) })
.from(sites) .from(sites)
@@ -200,10 +207,19 @@ export async function updateAndGenerateEndpointDestinations(
`Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}` `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
); );
// if the public key or endpoint has changed, update it otherwise continue
if (
site.endpoint === `${ip}:${port}` &&
site.publicKey === publicKey
) {
continue;
}
const [updatedClientSitesAssociationsCache] = await db const [updatedClientSitesAssociationsCache] = await db
.update(clientSitesAssociationsCache) .update(clientSitesAssociationsCache)
.set({ .set({
endpoint: `${ip}:${port}` endpoint: `${ip}:${port}`,
publicKey: publicKey
}) })
.where( .where(
and( and(
@@ -227,9 +243,9 @@ export async function updateAndGenerateEndpointDestinations(
} }
} }
logger.debug( // logger.debug(
`Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}` // `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}`
); // );
if (!updatedClient) { if (!updatedClient) {
logger.warn(`Client not found for olm: ${olmId}`); logger.warn(`Client not found for olm: ${olmId}`);
throw new Error("Client not found"); throw new Error("Client not found");
@@ -245,9 +261,9 @@ export async function updateAndGenerateEndpointDestinations(
} }
} }
} else if (newtId) { } else if (newtId) {
logger.debug( // logger.debug(
`Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}` // `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}`
); // );
const { session, newt: newtSession } = const { session, newt: newtSession } =
await validateNewtSessionToken(token); await validateNewtSessionToken(token);
@@ -407,7 +423,7 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
{ {
siteId: siteId, siteId: siteId,
publicKey: site.publicKey, publicKey: site.publicKey,
endpoint: newEndpoint, endpoint: newEndpoint
}, },
client.olmId client.olmId
); );

View File

@@ -79,12 +79,12 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
// TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly) // TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly)
} }
// if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) { if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
// logger.warn( logger.warn(
// `Site ${existingSite.siteId} last hole punch is too old, skipping` `Site ${existingSite.siteId} last hole punch is too old, skipping`
// ); );
// return; return;
// } }
// update the endpoint and the public key // update the endpoint and the public key
const [site] = await db const [site] = await db

View File

@@ -1,9 +1,9 @@
import { generateSessionToken } from "@server/auth/sessions/app"; import { generateSessionToken } from "@server/auth/sessions/app";
import { db } from "@server/db"; import { clients, db, ExitNode, exitNodes, sites, clientSitesAssociationsCache } from "@server/db";
import { olms } from "@server/db"; import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response"; import response from "@server/lib/response";
import { eq } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -15,11 +15,13 @@ import {
import { verifyPassword } from "@server/auth/password"; import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger"; import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { listExitNodes } from "#dynamic/lib/exitNodes";
export const olmGetTokenBodySchema = z.object({ export const olmGetTokenBodySchema = z.object({
olmId: z.string(), olmId: z.string(),
secret: z.string(), secret: z.string(),
token: z.string().optional() token: z.string().optional(),
orgId: z.string().optional()
}); });
export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>; export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>;
@@ -40,7 +42,7 @@ export async function getOlmToken(
); );
} }
const { olmId, secret, token } = parsedBody.data; const { olmId, secret, token, orgId } = parsedBody.data;
try { try {
if (token) { if (token) {
@@ -61,11 +63,12 @@ export async function getOlmToken(
} }
} }
const existingOlmRes = await db const [existingOlm] = await db
.select() .select()
.from(olms) .from(olms)
.where(eq(olms.olmId, olmId)); .where(eq(olms.olmId, olmId));
if (!existingOlmRes || !existingOlmRes.length) {
if (!existingOlm) {
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
@@ -74,12 +77,11 @@ export async function getOlmToken(
); );
} }
const existingOlm = existingOlmRes[0];
const validSecret = await verifyPassword( const validSecret = await verifyPassword(
secret, secret,
existingOlm.secretHash existingOlm.secretHash
); );
if (!validSecret) { if (!validSecret) {
if (config.getRawConfig().app.log_failed_attempts) { if (config.getRawConfig().app.log_failed_attempts) {
logger.info( logger.info(
@@ -96,11 +98,78 @@ export async function getOlmToken(
const resToken = generateSessionToken(); const resToken = generateSessionToken();
await createOlmSession(resToken, existingOlm.olmId); await createOlmSession(resToken, existingOlm.olmId);
let orgIdToUse = orgId;
if (!orgIdToUse) {
if (!existingOlm.clientId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Olm is not associated with a client, orgId is required"
)
);
}
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, existingOlm.clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Olm's associated client not found, orgId is required"
)
);
}
orgIdToUse = client.orgId;
}
// Get all exit nodes from sites where the client has peers
const clientSites = await db
.select()
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, existingOlm.clientId!));
// Extract unique exit node IDs
const exitNodeIds = Array.from(
new Set(
clientSites
.map(({ sites: site }) => site.exitNodeId)
.filter((id): id is number => id !== null)
)
);
let allExitNodes: ExitNode[] = [];
if (exitNodeIds.length > 0) {
allExitNodes = await db
.select()
.from(exitNodes)
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
}
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
return {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
};
});
logger.debug("Token created successfully"); logger.debug("Token created successfully");
return response<{ token: string }>(res, { return response<{
token: string;
exitNodes: { publicKey: string; endpoint: string }[];
}>(res, {
data: { data: {
token: resToken token: resToken,
exitNodes: exitNodesHpData
}, },
success: true, success: true,
error: false, error: false,

View File

@@ -59,7 +59,7 @@ export const startOlmOfflineChecker = (): void => {
// Send a disconnect message to the client if connected // Send a disconnect message to the client if connected
try { try {
await sendTerminateClient(offlineClient.clientId); // terminate first await sendTerminateClient(offlineClient.clientId, offlineClient.olmId); // terminate first
// wait a moment to ensure the message is sent // wait a moment to ensure the message is sent
await new Promise(resolve => setTimeout(resolve, 1000)); await new Promise(resolve => setTimeout(resolve, 1000));
await disconnectClient(offlineClient.olmId); await disconnectClient(offlineClient.olmId);

View File

@@ -34,21 +34,28 @@ import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy"; import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context; const { message, client: c, sendToClient } = context;
const olm = c as Olm; const olm = c as Olm;
const now = new Date().getTime() / 1000; const now = Math.floor(Date.now() / 1000);
if (!olm) { if (!olm) {
logger.warn("Olm not found"); logger.warn("Olm not found");
return; return;
} }
const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } = const {
message.data; publicKey,
relay,
olmVersion,
orgId,
doNotCreateNewClient,
token: userToken
} = message.data;
let client: Client | undefined; let client: Client | undefined;
let org: Org | undefined; let org: Org | undefined;
@@ -63,7 +70,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
olm.name || "User Device", olm.name || "User Device",
// doNotCreateNewClient ? true : false // doNotCreateNewClient ? true : false
true // for now never create a new client automatically because we create the users clients when they are added to the org true // for now never create a new client automatically because we create the users clients when they are added to the org
// this means that the rebuildClientAssociationsFromClient call below issue is not a problem // this means that the rebuildClientAssociationsFromClient call below issue is not a problem
); );
client = clientRes; client = clientRes;
@@ -113,12 +120,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}` `Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
); );
await db if (olm.clientId !== client.clientId) { // we only need to do this if the client is changing
.update(olms) await db
.set({ .update(olms)
clientId: client.clientId .set({
}) clientId: client.clientId
.where(eq(olms.olmId, olm.olmId)); })
.where(eq(olms.olmId, olm.olmId));
}
} else { } else {
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm has no client ID!"); logger.warn("Olm has no client ID!");
@@ -159,41 +168,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} }
if (client.exitNodeId) { if (olmVersion && olm.version !== olmVersion) {
// TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER
// Get the exit node
const allExitNodes = await listExitNodes(client.orgId, true); // FILTER THE ONLINE ONES
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
return {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
};
});
// Send holepunch message
await sendToClient(olm.olmId, {
type: "olm/wg/holepunch/all",
data: {
exitNodes: exitNodesHpData
}
});
if (!olmVersion) {
// THIS IS FOR BACKWARDS COMPATIBILITY
// THE OLDER CLIENTS DID NOT SEND THE VERSION
await sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: allExitNodes[0].publicKey,
endpoint: allExitNodes[0].endpoint
}
});
}
}
if (olmVersion) {
await db await db
.update(olms) .update(olms)
.set({ .set({
@@ -202,10 +177,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.where(eq(olms.olmId, olm.olmId)); .where(eq(olms.olmId, olm.olmId));
} }
// if (now - (client.lastHolePunch || 0) > 6) { // this prevents us from accepting a register from an olm that has not hole punched yet.
// logger.warn("Client last hole punch is too old, skipping all sites"); // the olm will pump the register so we can keep checking
// return; if (now - (client.lastHolePunch || 0) > 5) {
// } logger.warn(
"Client last hole punch is too old; skipping this register"
);
return;
}
if (client.pubKey !== publicKey) { if (client.pubKey !== publicKey) {
logger.info( logger.info(
@@ -319,7 +298,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.warn(`Exit node not found for site ${site.siteId}`); logger.warn(`Exit node not found for site ${site.siteId}`);
continue; continue;
} }
endpoint = `${exitNode.endpoint}:21820`; endpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
} }
const allSiteResources = await db // only get the site resources that this client has access to const allSiteResources = await db // only get the site resources that this client has access to

View File

@@ -2,7 +2,7 @@ import { db, exitNodes, sites } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, clientSitesAssociationsCache, Olm } from "@server/db"; import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
import { and, eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { updatePeer } from "../newt/peers"; import { updatePeer as newtUpdatePeer } from "../newt/peers";
import logger from "@server/logger"; import logger from "@server/logger";
export const handleOlmRelayMessage: MessageHandler = async (context) => { export const handleOlmRelayMessage: MessageHandler = async (context) => {
@@ -79,18 +79,20 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
); );
// update the peer on the exit node // update the peer on the exit node
await updatePeer(siteId, client.pubKey, { await newtUpdatePeer(siteId, client.pubKey, {
endpoint: "" // this removes the endpoint endpoint: "" // this removes the endpoint so the exit node knows to relay
}); });
sendToClient(olm.olmId, { return {
type: "olm/wg/peer/relay", message: {
type: "olm/wg/peer/relay",
data: { data: {
siteId: siteId, siteId: siteId,
endpoint: exitNode.endpoint, endpoint: exitNode.endpoint,
publicKey: exitNode.publicKey publicKey: exitNode.publicKey
} }
}); },
broadcast: false,
return; excludeSender: false
};
}; };

View File

@@ -0,0 +1,185 @@
import {
Client,
clientSiteResourcesAssociationsCache,
db,
ExitNode,
Org,
orgs,
roleClients,
roles,
siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import {
clients,
clientSitesAssociationsCache,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import {
generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers";
// Handles "olm/wg/server/peer/add": once a client (olm) has hole punched an
// exit node, add the client as a WireGuard peer on the site's newt, then
// return the site's connection details to the olm so it can add the site as
// a peer on its side. Replies directly to the sender (no broadcast).
export const handleOlmServerPeerAddMessage: MessageHandler = async (
    context
) => {
    logger.info("Handling olm server peer add message!");
    const { message, client: c } = context;
    const olm = c as Olm;

    if (!olm) {
        logger.warn("Olm not found");
        return;
    }

    const { siteId } = message.data;
    if (siteId === undefined || siteId === null) {
        logger.error(
            "handleOlmServerPeerAddMessage: Message is missing siteId"
        );
        return;
    }

    // Look up the site the client wants to peer with.
    const [site] = await db
        .select()
        .from(sites)
        .where(eq(sites.siteId, siteId))
        .limit(1);

    if (!site) {
        logger.error(
            `handleOlmServerPeerAddMessage: Site with ID ${siteId} not found`
        );
        return;
    }

    if (!site.endpoint) {
        logger.error(
            `handleOlmServerPeerAddMessage: Site with ID ${siteId} has no endpoint`
        );
        return;
    }

    // Resolve the client that owns this olm.
    if (!olm.clientId) {
        logger.error(
            `handleOlmServerPeerAddMessage: Olm with ID ${olm.olmId} has no clientId`
        );
        return;
    }

    const [client] = await db
        .select()
        .from(clients)
        .where(and(eq(clients.clientId, olm.clientId)))
        .limit(1);

    if (!client) {
        logger.error(
            `handleOlmServerPeerAddMessage: Client with ID ${olm.clientId} not found`
        );
        return;
    }

    if (!client.pubKey) {
        logger.error(
            `handleOlmServerPeerAddMessage: Client with ID ${client.clientId} has no public key`
        );
        return;
    }

    // Find an endpoint observed during the CURRENT session: the association
    // cache row must carry the client's current public key, otherwise the
    // cached endpoint could be stale from a previous session.
    const currentSessionSiteAssociationCaches = await db
        .select()
        .from(clientSitesAssociationsCache)
        .where(
            and(
                eq(clientSitesAssociationsCache.clientId, client.clientId),
                isNotNull(clientSitesAssociationsCache.endpoint),
                eq(clientSitesAssociationsCache.publicKey, client.pubKey) // limit it to the current session its connected with otherwise the endpoint could be stale
            )
        );

    // Pick the first non-empty endpoint from the cache rows.
    const endpoint: string | null =
        currentSessionSiteAssociationCaches.find((assoc) => assoc.endpoint)
            ?.endpoint ?? null;

    if (!endpoint) {
        logger.error(
            `handleOlmServerPeerAddMessage: No endpoint found for client ${client.clientId}`
        );
        return;
    }

    // Register the client as a peer on the site's exit-node side.
    await newtAddPeer(siteId, {
        publicKey: client.pubKey,
        allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
        endpoint: endpoint // this is the client's endpoint with reference to the site's exit node
    });

    const allSiteResources = await db // only get the site resources that this client has access to
        .select()
        .from(siteResources)
        .innerJoin(
            clientSiteResourcesAssociationsCache,
            eq(
                siteResources.siteResourceId,
                clientSiteResourcesAssociationsCache.siteResourceId
            )
        )
        .where(
            and(
                eq(siteResources.siteId, site.siteId),
                eq(
                    clientSiteResourcesAssociationsCache.clientId,
                    client.clientId
                )
            )
        );

    // Return connect message with all site configurations
    return {
        message: {
            type: "olm/wg/peer/add",
            data: {
                siteId: site.siteId,
                endpoint: site.endpoint,
                publicKey: site.publicKey,
                serverIP: site.address,
                serverPort: site.listenPort,
                remoteSubnets: generateRemoteSubnets(
                    allSiteResources.map(({ siteResources }) => siteResources)
                ),
                aliases: generateAliasConfig(
                    allSiteResources.map(({ siteResources }) => siteResources)
                )
            }
        },
        broadcast: false,
        excludeSender: false
    };
};

View File

@@ -7,3 +7,4 @@ export * from "./deleteUserOlm";
export * from "./listUserOlms"; export * from "./listUserOlms";
export * from "./deleteUserOlm"; export * from "./deleteUserOlm";
export * from "./getUserOlm"; export * from "./getUserOlm";
export * from "./handleOlmServerPeerAddMessage";

View File

@@ -3,6 +3,7 @@ import { clients, olms, newts, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
import { exit } from "process";
export async function addPeer( export async function addPeer(
clientId: number, clientId: number,
@@ -110,3 +111,40 @@ export async function updatePeer(
logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`); logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
} }
/**
 * Kicks off the hole-punch handshake that precedes adding a site peer to a
 * client. Sends an "olm/wg/peer/holepunch/site/add" message to the client's
 * olm with the target site and the exit node it should punch toward.
 *
 * @param clientId - ID of the client whose olm should start the handshake.
 * @param peer - Target site plus the exit node's public key and endpoint.
 * @param olmId - Optional olm to address directly; when omitted, the olm is
 *                looked up by clientId.
 * @throws Error when no olm exists for the given clientId.
 */
export async function initPeerAddHandshake(
    clientId: number,
    peer: {
        siteId: number;
        exitNode: {
            publicKey: string;
            endpoint: string;
        };
    },
    olmId?: string
): Promise<void> {
    if (!olmId) {
        // Resolve the olm from the client when the caller did not supply one.
        const [olm] = await db
            .select()
            .from(olms)
            .where(eq(olms.clientId, clientId))
            .limit(1);
        if (!olm) {
            // NOTE: clientId is a client ID, not an olm ID — say so in the error.
            throw new Error(`Olm for client ID ${clientId} not found`);
        }
        olmId = olm.olmId;
    }

    await sendToClient(olmId, {
        type: "olm/wg/peer/holepunch/site/add",
        data: {
            siteId: peer.siteId,
            exitNode: {
                publicKey: peer.exitNode.publicKey,
                endpoint: peer.exitNode.endpoint
            }
        }
    });

    logger.info(
        `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
    );
}

View File

@@ -11,23 +11,25 @@ import {
handleOlmRegisterMessage, handleOlmRegisterMessage,
handleOlmRelayMessage, handleOlmRelayMessage,
handleOlmPingMessage, handleOlmPingMessage,
startOlmOfflineChecker startOlmOfflineChecker,
handleOlmServerPeerAddMessage
} from "../olm"; } from "../olm";
import { handleHealthcheckStatusMessage } from "../target"; import { handleHealthcheckStatusMessage } from "../target";
import { MessageHandler } from "./types"; import { MessageHandler } from "./types";
export const messageHandlers: Record<string, MessageHandler> = { export const messageHandlers: Record<string, MessageHandler> = {
"newt/wg/register": handleNewtRegisterMessage, "olm/wg/server/peer/add": handleOlmServerPeerAddMessage,
"olm/wg/register": handleOlmRegisterMessage, "olm/wg/register": handleOlmRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"olm/wg/relay": handleOlmRelayMessage, "olm/wg/relay": handleOlmRelayMessage,
"olm/ping": handleOlmPingMessage, "olm/ping": handleOlmPingMessage,
"newt/wg/register": handleNewtRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"newt/socket/status": handleDockerStatusMessage, "newt/socket/status": handleDockerStatusMessage,
"newt/socket/containers": handleDockerContainersMessage, "newt/socket/containers": handleDockerContainersMessage,
"newt/ping/request": handleNewtPingRequestMessage, "newt/ping/request": handleNewtPingRequestMessage,
"newt/blueprint/apply": handleApplyBlueprintMessage, "newt/blueprint/apply": handleApplyBlueprintMessage,
"newt/healthcheck/status": handleHealthcheckStatusMessage, "newt/healthcheck/status": handleHealthcheckStatusMessage
}; };
startOlmOfflineChecker(); // this is to handle the offline check for olms startOlmOfflineChecker(); // this is to handle the offline check for olms