Update schmea; create client when registering

This commit is contained in:
Owen
2025-11-03 15:42:22 -08:00
parent 43590896e9
commit d30743a428
5 changed files with 200 additions and 33 deletions

View File

@@ -607,6 +607,10 @@ export const clients = pgTable("clients", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null" onDelete: "set null"
}), }),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
name: varchar("name").notNull(), name: varchar("name").notNull(),
pubKey: varchar("pubKey"), pubKey: varchar("pubKey"),
subnet: varchar("subnet").notNull(), subnet: varchar("subnet").notNull(),
@@ -638,6 +642,11 @@ export const olms = pgTable("olms", {
dateCreated: varchar("dateCreated").notNull(), dateCreated: varchar("dateCreated").notNull(),
version: text("version"), version: text("version"),
clientId: integer("clientId").references(() => clients.clientId, { clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade" onDelete: "cascade"
}) })
}); });

View File

@@ -25,11 +25,10 @@ export const dnsRecords = sqliteTable("dnsRecords", {
recordType: text("recordType").notNull(), // "NS" | "CNAME" | "A" | "TXT" recordType: text("recordType").notNull(), // "NS" | "CNAME" | "A" | "TXT"
baseDomain: text("baseDomain"), baseDomain: text("baseDomain"),
value: text("value").notNull(), value: text("value").notNull(),
verified: integer("verified", { mode: "boolean" }).notNull().default(false), verified: integer("verified", { mode: "boolean" }).notNull().default(false)
}); });
export const orgs = sqliteTable("orgs", { export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(), orgId: text("orgId").primaryKey(),
name: text("name").notNull(), name: text("name").notNull(),
@@ -142,9 +141,10 @@ export const resources = sqliteTable("resources", {
onDelete: "set null" onDelete: "set null"
}), }),
headers: text("headers"), // comma-separated list of headers to add to the request headers: text("headers"), // comma-separated list of headers to add to the request
proxyProtocol: integer("proxyProtocol", { mode: "boolean" }).notNull().default(false), proxyProtocol: integer("proxyProtocol", { mode: "boolean" })
.notNull()
.default(false),
proxyProtocolVersion: integer("proxyProtocolVersion").default(1) proxyProtocolVersion: integer("proxyProtocolVersion").default(1)
}); });
export const targets = sqliteTable("targets", { export const targets = sqliteTable("targets", {
@@ -315,6 +315,10 @@ export const clients = sqliteTable("clients", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null" onDelete: "set null"
}), }),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
name: text("name").notNull(), name: text("name").notNull(),
pubKey: text("pubKey"), pubKey: text("pubKey"),
subnet: text("subnet").notNull(), subnet: text("subnet").notNull(),
@@ -347,6 +351,10 @@ export const olms = sqliteTable("olms", {
dateCreated: text("dateCreated").notNull(), dateCreated: text("dateCreated").notNull(),
version: text("version"), version: text("version"),
clientId: integer("clientId").references(() => clients.clientId, { clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, { // optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade" onDelete: "cascade"
}) })
}); });

View File

@@ -197,7 +197,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
// // set the item in the database if it is offline // // set the item in the database if it is offline
// if (isActuallyOnline != node.online) { // if (isActuallyOnline != node.online) {
// await db // await trx
// .update(exitNodes) // .update(exitNodes)
// .set({ online: isActuallyOnline }) // .set({ online: isActuallyOnline })
// .where(eq(exitNodes.exitNodeId, node.exitNodeId)); // .where(eq(exitNodes.exitNodeId, node.exitNodeId));

View File

@@ -182,14 +182,13 @@ export async function createClient(
const randomExitNode = const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const adminRole = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1); .limit(1);
if (adminRole.length === 0) { if (!adminRole) {
trx.rollback();
return next( return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`) createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
); );
@@ -207,12 +206,12 @@ export async function createClient(
.returning(); .returning();
await trx.insert(roleClients).values({ await trx.insert(roleClients).values({
roleId: adminRole[0].roleId, roleId: adminRole.roleId,
clientId: newClient.clientId clientId: newClient.clientId
}); });
if (req.user && req.userOrgRoleId != adminRole[0].roleId) { if (req.user && req.userOrgRoleId != adminRole.roleId) {
// make sure the user can access the site // make sure the user can access the client
trx.insert(userClients).values({ trx.insert(userClients).values({
userId: req.user?.userId!, userId: req.user?.userId!,
clientId: newClient.clientId clientId: newClient.clientId

View File

@@ -1,10 +1,22 @@
import { db, ExitNode } from "@server/db"; import {
Client,
db,
ExitNode,
orgs,
roleClients,
roles,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db"; import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers"; import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger"; import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes"; import { listExitNodes } from "#dynamic/lib/exitNodes";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("Handling register olm message!");
@@ -17,15 +29,62 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.warn("Olm not found"); logger.warn("Olm not found");
return; return;
} }
if (!olm.clientId) {
logger.warn("Olm has no client ID!"); const { publicKey, relay, olmVersion, orgId, deviceName } = message.data;
let client: Client;
if (orgId) {
if (!olm.userId) {
logger.warn("Olm has no user ID to verify org change!");
return;
}
try {
client = await getOrCreateOrgClient(orgId, olm.userId, deviceName);
} catch (err) {
logger.error(
`Error switching olm client ${olm.olmId} to org ${orgId}: ${err}`
);
return;
}
if (!client) {
logger.warn("Client not found");
return;
}
logger.debug(
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
);
await db
.update(olms)
.set({
clientId: client.clientId
})
.where(eq(olms.olmId, olm.olmId));
} else {
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
logger.debug(`Using last connected org for client ${olm.clientId}`);
[client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
}
if (!client) {
logger.warn("Client ID not found");
return; return;
} }
const clientId = olm.clientId;
const { publicKey, relay, olmVersion } = message.data;
logger.debug( logger.debug(
`Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}` `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
); );
if (!publicKey) { if (!publicKey) {
@@ -33,18 +92,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} }
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) { if (client.exitNodeId) {
// TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER // TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER
@@ -103,7 +150,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({ .set({
pubKey: publicKey pubKey: publicKey
}) })
.where(eq(clients.clientId, olm.clientId)); .where(eq(clients.clientId, client.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata // set isRelay to false for all of the client's sites to reset the connection metadata
await db await db
@@ -111,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({ .set({
isRelayed: relay == true isRelayed: relay == true
}) })
.where(eq(clientSites.clientId, olm.clientId)); .where(eq(clientSites.clientId, client.clientId));
} }
// Get all sites data // Get all sites data
@@ -145,7 +192,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Validate endpoint and hole punch status // Validate endpoint and hole punch status
if (!site.endpoint) { if (!site.endpoint) {
logger.warn(`In olm register: site ${site.siteId} has no endpoint, skipping`); logger.warn(
`In olm register: site ${site.siteId} has no endpoint, skipping`
);
continue; continue;
} }
@@ -240,3 +289,105 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
excludeSender: false excludeSender: false
}; };
}; };
async function getOrCreateOrgClient(
orgId: string,
userId: string,
deviceName?: string,
trx: Transaction | typeof db = db
): Promise<Client> {
let client: Client;
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error("Org not found");
}
if (!org.subnet) {
throw new Error("Org has no subnet defined");
}
// Verify that the user belongs to the org
const [userOrg] = await trx
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
throw new Error("User does not belong to org");
}
// check if the user has a client in the org and if not then create a client for them
const [existingClient] = await trx
.select()
.from(clients)
.where(and(eq(clients.orgId, orgId), eq(clients.userId, userId)))
.limit(1);
if (!existingClient) {
logger.debug(
`Client does not exist in org ${orgId}, creating new client for user ${userId}`
);
// TODO: more intelligent way to pick the exit node
const exitNodesList = await listExitNodes(orgId);
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
throw new Error("Admin role not found");
}
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
throw new Error("No available subnet found");
}
const subnet = newSubnet.split("/")[0];
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: randomExitNode.exitNodeId,
orgId,
name: deviceName || "User Device",
subnet: updatedSubnet,
type: "olm",
userId: userId
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: newClient.clientId
});
if (userOrg.roleId != adminRole.roleId) {
// make sure the user can access the client
trx.insert(userClients).values({
userId,
clientId: newClient.clientId
});
}
client = newClient;
} else {
client = existingClient;
}
return client;
}