From c4f48f5748af5eba3045ccba01a56d3f0cba78bf Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 22 Mar 2026 14:29:47 -0700 Subject: [PATCH] WIP - more conversion --- .../handleOlmServerInitAddPeerHandshake.ts | 240 +++++++++--------- .../olm/handleOlmServerPeerAddMessage.ts | 42 ++- 2 files changed, 135 insertions(+), 147 deletions(-) diff --git a/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts b/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts index 54badb2dc..0eda41e04 100644 --- a/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts +++ b/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts @@ -4,10 +4,12 @@ import { db, exitNodes, Site, - siteResources + siteNetworks, + siteResources, + sites } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { clients, Olm, sites } from "@server/db"; +import { clients, Olm } from "@server/db"; import { and, eq, or } from "drizzle-orm"; import logger from "@server/logger"; import { initPeerAddHandshake } from "./peers"; @@ -44,20 +46,31 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async ( const { siteId, resourceId, chainId } = message.data; - let site: Site | null = null; + const sendCancel = async () => { + await sendToClient( + olm.olmId, + { + type: "olm/wg/peer/chain/cancel", + data: { chainId } + }, + { incrementConfigVersion: false } + ).catch((error) => { + logger.warn(`Error sending message:`, error); + }); + }; + + let sitesToProcess: Site[] = []; + if (siteId) { - // get the site const [siteRes] = await db .select() .from(sites) .where(eq(sites.siteId, siteId)) .limit(1); if (siteRes) { - site = siteRes; + sitesToProcess = [siteRes]; } - } - - if (resourceId && !site) { + } else if (resourceId) { const resources = await db .select() .from(siteResources) @@ -72,27 +85,17 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async ( ); if (!resources || resources.length === 0) { - logger.error(`handleOlmServerPeerAddMessage: Resource not found`); - // cancel the request from the olm side to not keep doing this - await sendToClient( - olm.olmId, - { - type: "olm/wg/peer/chain/cancel", - data: { - chainId - } - }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); + logger.error( + `handleOlmServerInitAddPeerHandshake: Resource not found` + ); + await sendCancel(); return; } if (resources.length > 1) { // error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches logger.error( - `handleOlmServerPeerAddMessage: Multiple resources found matching the criteria` + `handleOlmServerInitAddPeerHandshake: Multiple resources found matching the criteria` ); return; } @@ -117,125 +120,120 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async ( if (currentResourceAssociationCaches.length === 0) { logger.error( - `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}` + `handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}` ); - // cancel the request from the olm side to not keep doing this - await sendToClient( - olm.olmId, - { - type: "olm/wg/peer/chain/cancel", - data: { - chainId - } - }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); + await sendCancel(); return; } - const siteIdFromResource = resource.siteId; - - // get the site - const [siteRes] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteIdFromResource)); - if (!siteRes) { + if (!resource.networkId) { logger.error( - `handleOlmServerPeerAddMessage: Site with ID ${site} not found` + `handleOlmServerInitAddPeerHandshake: Resource ${resource.siteResourceId} has no network` ); + await sendCancel(); return; } - site = siteRes; + // Get all sites associated with this resource's network via siteNetworks + const siteRows = await db + .select({ siteId: siteNetworks.siteId }) + .from(siteNetworks) + .where(eq(siteNetworks.networkId, resource.networkId)); + + if (!siteRows || siteRows.length === 0) { + logger.error( + `handleOlmServerInitAddPeerHandshake: No sites found for resource ${resource.siteResourceId}` + ); + await sendCancel(); + return; + } + + // Fetch full site objects for all network members + const foundSites = await Promise.all( + siteRows.map(async ({ siteId: sid }) => { + const [s] = await db + .select() + .from(sites) + .where(eq(sites.siteId, sid)) + .limit(1); + return s ?? null; + }) + ); + + sitesToProcess = foundSites.filter((s): s is Site => s !== null); } - if (!site) { - logger.error(`handleOlmServerPeerAddMessage: Site not found`); + if (sitesToProcess.length === 0) { + logger.error( + `handleOlmServerInitAddPeerHandshake: No sites to process` + ); + await sendCancel(); return; } - // check if the client can access this site using the cache - const currentSiteAssociationCaches = await db - .select() - .from(clientSitesAssociationsCache) - .where( - and( - eq(clientSitesAssociationsCache.clientId, client.clientId), - eq(clientSitesAssociationsCache.siteId, site.siteId) - ) - ); + let handshakeInitiated = false; - if (currentSiteAssociationCaches.length === 0) { - logger.error( - `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}` - ); - // cancel the request from the olm side to not keep doing this - await sendToClient( - olm.olmId, + for (const site of sitesToProcess) { + // Check if the client can access this site using the cache + const currentSiteAssociationCaches = await db + .select() + .from(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + eq(clientSitesAssociationsCache.siteId, site.siteId) + ) + ); + + if (currentSiteAssociationCaches.length === 0) { + logger.warn( + `handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to site ${site.siteId}, skipping` + ); + continue; + } + + if (!site.exitNodeId) { + logger.error( + `handleOlmServerInitAddPeerHandshake: Site ${site.siteId} has no exit node, skipping` + ); + continue; + } + + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)); + + if (!exitNode) { + logger.error( + `handleOlmServerInitAddPeerHandshake: Exit node not found for site ${site.siteId}, skipping` + ); + continue; + } + + // Trigger the peer add handshake — if the peer was already added this will be a no-op + await initPeerAddHandshake( + client.clientId, { - type: "olm/wg/peer/chain/cancel", - data: { - chainId + siteId: site.siteId, + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint } }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); - return; - } - - if (!site.exitNodeId) { - logger.error( - `handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node` - ); - // cancel the request from the olm side to not keep doing this - await sendToClient( olm.olmId, - { - type: "olm/wg/peer/chain/cancel", - data: { - chainId - } - }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); - return; - } - - // get the exit node from the side - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)); - - if (!exitNode) { - logger.error( - `handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node` + chainId ); - return; + + handshakeInitiated = true; } - // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch - // if it has already been added this will be a no-op - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { - siteId: site.siteId, - exitNode: { - publicKey: exitNode.publicKey, - endpoint: exitNode.endpoint - } - }, - olm.olmId, - chainId - ); + if (!handshakeInitiated) { + logger.error( + `handleOlmServerInitAddPeerHandshake: No accessible sites with valid exit nodes found, cancelling chain` + ); + await sendCancel(); + } return; -}; +}; \ No newline at end of file diff --git a/server/routers/olm/handleOlmServerPeerAddMessage.ts b/server/routers/olm/handleOlmServerPeerAddMessage.ts index 64284f493..5f46ea84c 100644 --- a/server/routers/olm/handleOlmServerPeerAddMessage.ts +++ b/server/routers/olm/handleOlmServerPeerAddMessage.ts @@ -1,43 +1,25 @@ import { - Client, clientSiteResourcesAssociationsCache, db, - ExitNode, - Org, - orgs, - roleClients, - roles, + networks, + siteNetworks, siteResources, - Transaction, - userClients, - userOrgs, - users } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; import { clients, clientSitesAssociationsCache, - exitNodes, Olm, - olms, sites } from "@server/db"; import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; -import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; -import { listExitNodes } from "#dynamic/lib/exitNodes"; import { generateAliasConfig, - getNextAvailableClientSubnet } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; -import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; -import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; -import { validateSessionToken } from "@server/auth/sessions/app"; -import config from "@server/lib/config"; import { addPeer as newtAddPeer, - deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; export const handleOlmServerPeerAddMessage: MessageHandler = async ( @@ -153,13 +135,21 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async ( clientSiteResourcesAssociationsCache.siteResourceId ) ) - .where( + .innerJoin( + networks, + eq(siteResources.networkId, networks.networkId) + ) + .innerJoin( + siteNetworks, and( - eq(siteResources.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) + eq(networks.networkId, siteNetworks.networkId), + eq(siteNetworks.siteId, site.siteId) + ) + ) + .where( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId ) );