Compare commits

..

1 Commits

Author SHA1 Message Date
Owen Schwartz
293d9865b4 Merge pull request #3014 from fosrl/dev
1.18.3
2026-05-06 16:30:36 -07:00
23 changed files with 286 additions and 466 deletions

View File

@@ -25,9 +25,9 @@ import { tierMatrix } from "./billing/tierMatrix";
export async function calculateUserClientsForOrgs( export async function calculateUserClientsForOrgs(
userId: string, userId: string,
trx: Transaction | typeof db = db trx?: Transaction
): Promise<void> { ): Promise<void> {
const execute = async (transaction: Transaction | typeof db) => { const execute = async (transaction: Transaction) => {
const orgCache = new Map<string, typeof orgs.$inferSelect | null>(); const orgCache = new Map<string, typeof orgs.$inferSelect | null>();
const adminRoleCache = new Map< const adminRoleCache = new Map<
string, string,
@@ -437,7 +437,7 @@ export async function calculateUserClientsForOrgs(
async function cleanupOrphanedClients( async function cleanupOrphanedClients(
userId: string, userId: string,
trx: Transaction | typeof db, trx: Transaction,
userOrgIds: string[] = [] userOrgIds: string[] = []
): Promise<void> { ): Promise<void> {
// Find all OLM clients for this user that should be deleted // Find all OLM clients for this user that should be deleted

View File

@@ -124,7 +124,7 @@ export function computeBuckets(
let totalDowntime = 0; let totalDowntime = 0;
for (let d = 0; d < days; d++) { for (let d = 0; d < days; d++) {
const dayStartSec = todayMidnightSec - (days - 1 - d) * 86400; const dayStartSec = todayMidnightSec - (days - d) * 86400;
const dayEndSec = dayStartSec + 86400; const dayEndSec = dayStartSec + 86400;
const dayEvents = events.filter( const dayEvents = events.filter(

View File

@@ -485,133 +485,6 @@ async function syncAcmeCertsFromHttp(endpoint: string): Promise<void> {
} }
} }
async function storeCertForDomain(
domain: string,
certPem: string,
keyPem: string,
validatedX509: crypto.X509Certificate
): Promise<void> {
const wildcard = domain.startsWith("*.");
const existing = await db
.select()
.from(certificates)
.where(eq(certificates.domain, domain))
.limit(1);
let oldCertPem: string | null = null;
let oldKeyPem: string | null = null;
if (existing.length > 0 && existing[0].certFile) {
try {
const storedCertPem = decrypt(
existing[0].certFile,
config.getRawConfig().server.secret!
);
const wildcardUnchanged = existing[0].wildcard === wildcard;
if (storedCertPem === certPem && wildcardUnchanged) {
return;
}
oldCertPem = storedCertPem;
if (existing[0].keyFile) {
try {
oldKeyPem = decrypt(
existing[0].keyFile,
config.getRawConfig().server.secret!
);
} catch (keyErr) {
logger.debug(
`acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}`
);
}
}
} catch (err) {
logger.debug(
`acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}`
);
}
}
let expiresAt: number | null = null;
try {
expiresAt = Math.floor(
new Date(validatedX509.validTo).getTime() / 1000
);
} catch (err) {
logger.debug(
`acmeCertSync: could not parse cert expiry for ${domain}: ${err}`
);
}
const encryptedCert = encrypt(
certPem,
config.getRawConfig().server.secret!
);
const encryptedKey = encrypt(keyPem, config.getRawConfig().server.secret!);
const now = Math.floor(Date.now() / 1000);
const domainId = await findDomainId(domain);
if (domainId) {
logger.debug(
`acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"`
);
} else {
logger.debug(
`acmeCertSync: no matching domain record found for cert domain "${domain}"`
);
}
if (existing.length > 0) {
logger.debug(
`acmeCertSync: updating existing certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db
.update(certificates)
.set({
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
updatedAt: now,
wildcard,
...(domainId !== null && { domainId })
})
.where(eq(certificates.domain, domain));
logger.debug(
`acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(
domain,
domainId,
oldCertPem,
oldKeyPem
);
} else {
logger.debug(
`acmeCertSync: inserting new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db.insert(certificates).values({
domain,
domainId,
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
createdAt: now,
updatedAt: now,
wildcard
});
logger.debug(
`acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(domain, domainId, null, null);
}
}
function findAcmeJsonFiles(dirPath: string): string[] { function findAcmeJsonFiles(dirPath: string): string[] {
const results: string[] = []; const results: string[] = [];
let entries: fs.Dirent[]; let entries: fs.Dirent[];
@@ -702,16 +575,18 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
} }
for (const cert of allCerts) { for (const cert of allCerts) {
const mainDomain = cert?.domain?.main; const domain = cert?.domain?.main;
if (!mainDomain || typeof mainDomain !== "string") { if (!domain || typeof domain !== "string") {
logger.debug(`acmeCertSync: skipping cert with missing domain`); logger.debug(`acmeCertSync: skipping cert with missing domain`);
continue; continue;
} }
const { wildcard } = detectWildcard(domain, cert.domain?.sans);
if (!cert.certificate || !cert.key) { if (!cert.certificate || !cert.key) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${mainDomain} - empty certificate or key field` `acmeCertSync: skipping cert for ${domain} - empty certificate or key field`
); );
continue; continue;
} }
@@ -723,14 +598,14 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
keyPem = Buffer.from(cert.key, "base64").toString("utf8"); keyPem = Buffer.from(cert.key, "base64").toString("utf8");
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${mainDomain} - failed to base64-decode cert/key: ${err}` `acmeCertSync: skipping cert for ${domain} - failed to base64-decode cert/key: ${err}`
); );
continue; continue;
} }
if (!certPem.trim() || !keyPem.trim()) { if (!certPem.trim() || !keyPem.trim()) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${mainDomain} - blank PEM after base64 decode` `acmeCertSync: skipping cert for ${domain} - blank PEM after base64 decode`
); );
continue; continue;
} }
@@ -741,7 +616,7 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
const firstCertPemForValidation = extractFirstCert(certPem); const firstCertPemForValidation = extractFirstCert(certPem);
if (!firstCertPemForValidation) { if (!firstCertPemForValidation) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${mainDomain} - no PEM certificate block found` `acmeCertSync: skipping cert for ${domain} - no PEM certificate block found`
); );
continue; continue;
} }
@@ -753,7 +628,7 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
); );
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${mainDomain} - invalid X.509 certificate: ${err}` `acmeCertSync: skipping cert for ${domain} - invalid X.509 certificate: ${err}`
); );
continue; continue;
} }
@@ -763,40 +638,139 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
crypto.createPrivateKey(keyPem); crypto.createPrivateKey(keyPem);
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${mainDomain} - invalid private key: ${err}` `acmeCertSync: skipping cert for ${domain} - invalid private key: ${err}`
); );
continue; continue;
} }
// Collect all domains covered by this cert: main + every SAN. // Check if cert already exists in DB
// Each domain gets its own row in the certificates table so that const existing = await db
// lookups by any hostname on the cert succeed independently. .select()
const allDomains = new Set<string>([mainDomain]); .from(certificates)
if (Array.isArray(cert.domain?.sans)) { .where(and(eq(certificates.domain, domain)))
for (const san of cert.domain.sans) { .limit(1);
if (typeof san === "string" && san.trim()) {
allDomains.add(san.trim()); let oldCertPem: string | null = null;
let oldKeyPem: string | null = null;
if (existing.length > 0 && existing[0].certFile) {
try {
const storedCertPem = decrypt(
existing[0].certFile,
config.getRawConfig().server.secret!
);
const wildcardUnchanged = existing[0].wildcard === wildcard;
if (storedCertPem === certPem && wildcardUnchanged) {
// logger.debug(
// `acmeCertSync: cert for ${domain} is unchanged, skipping`
// );
continue;
} }
// Cert has changed; capture old values so we can send a correct
// update message to the newt after the DB write.
oldCertPem = storedCertPem;
if (existing[0].keyFile) {
try {
oldKeyPem = decrypt(
existing[0].keyFile,
config.getRawConfig().server.secret!
);
} catch (keyErr) {
logger.debug(
`acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}`
);
}
}
} catch (err) {
// Decryption failure means we should proceed with the update
logger.debug(
`acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}`
);
} }
} }
logger.debug( // Parse cert expiry from the validated X.509 certificate
`acmeCertSync: cert for ${mainDomain} covers ${allDomains.size} domain(s): ${[...allDomains].join(", ")}` let expiresAt: number | null = null;
); try {
expiresAt = Math.floor(
new Date(validatedX509.validTo).getTime() / 1000
);
} catch (err) {
logger.debug(
`acmeCertSync: could not parse cert expiry for ${domain}: ${err}`
);
}
for (const domain of allDomains) { const encryptedCert = encrypt(
try { certPem,
await storeCertForDomain( config.getRawConfig().server.secret!
domain, );
certPem, const encryptedKey = encrypt(
keyPem, keyPem,
validatedX509 config.getRawConfig().server.secret!
); );
} catch (err) { const now = Math.floor(Date.now() / 1000);
logger.error(
`acmeCertSync: error storing cert for domain "${domain}": ${err}` const domainId = await findDomainId(domain);
); if (domainId) {
} logger.debug(
`acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"`
);
} else {
logger.debug(
`acmeCertSync: no matching domain record found for cert domain "${domain}"`
);
}
if (existing.length > 0) {
logger.debug(
`acmeCertSync: updating existing certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db
.update(certificates)
.set({
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
updatedAt: now,
wildcard,
...(domainId !== null && { domainId })
})
.where(eq(certificates.domain, domain));
logger.debug(
`acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(
domain,
domainId,
oldCertPem,
oldKeyPem
);
} else {
logger.debug(
`acmeCertSync: inserting new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db.insert(certificates).values({
domain,
domainId,
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
createdAt: now,
updatedAt: now,
wildcard
});
logger.debug(
`acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
// For a brand-new cert, push to any SSL resources that were waiting for it
await pushCertUpdateToAffectedNewts(domain, domainId, null, null);
} }
} }
} }

View File

@@ -14,7 +14,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { clients, db, primaryDb, Client } from "@server/db"; import { clients, db } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -122,12 +122,8 @@ export async function addUserRole(
); );
} }
let newUserRole: { let newUserRole: { userId: string; orgId: string; roleId: number } | null =
userId: string; null;
orgId: string;
roleId: number;
} | null = null;
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const inserted = await trx const inserted = await trx
.insert(userOrgRoles) .insert(userOrgRoles)
@@ -153,19 +149,11 @@ export async function addUserRole(
) )
); );
orgClientsToRebuild = orgClients; for (const orgClient of orgClients) {
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after adding role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: newUserRole ?? { userId, orgId: role.orgId, roleId }, data: newUserRole ?? { userId, orgId: role.orgId, roleId },
success: true, success: true,

View File

@@ -14,7 +14,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { db, primaryDb, Client } from "@server/db"; import { db } from "@server/db";
import { userOrgRoles, userOrgs, roles, clients } from "@server/db"; import { userOrgRoles, userOrgs, roles, clients } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -129,7 +129,6 @@ export async function removeUserRole(
} }
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -151,19 +150,11 @@ export async function removeUserRole(
) )
); );
orgClientsToRebuild = orgClients; for (const orgClient of orgClients) {
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after removing role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { userId, orgId: role.orgId, roleId }, data: { userId, orgId: role.orgId, roleId },
success: true, success: true,

View File

@@ -13,7 +13,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { clients, db, primaryDb, Client } from "@server/db"; import { clients, db } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -115,7 +115,6 @@ export async function setUserOrgRoles(
); );
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -143,19 +142,11 @@ export async function setUserOrgRoles(
and(eq(clients.userId, userId), eq(clients.orgId, orgId)) and(eq(clients.userId, userId), eq(clients.orgId, orgId))
); );
orgClientsToRebuild = orgClients; for (const orgClient of orgClients) {
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after setting roles: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { userId, orgId, roleIds: uniqueRoleIds }, data: { userId, orgId, roleIds: uniqueRoleIds },
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, orgs, userOrgs, users, primaryDb } from "@server/db"; import { db, orgs, userOrgs, users } from "@server/db";
import { eq, and, inArray, not } from "drizzle-orm"; import { eq, and, inArray, not } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -218,18 +218,13 @@ export async function deleteMyAccount(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx.delete(users).where(eq(users.userId, userId)); await trx.delete(users).where(eq(users.userId, userId));
await calculateUserClientsForOrgs(userId, trx);
// loop through the other orgs and decrement the count // loop through the other orgs and decrement the count
for (const userOrg of otherOrgsTheUserWasIn) { for (const userOrg of otherOrgsTheUserWasIn) {
await usageService.add(userOrg.orgId, FeatureId.USERS, -1, trx); await usageService.add(userOrg.orgId, FeatureId.USERS, -1, trx);
} }
}); });
calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
logger.error(
`Failed to calculate user clients after deleting account for user ${userId}: ${e}`
);
});
try { try {
await invalidateSession(session.sessionId); await invalidateSession(session.sessionId);
} catch (error) { } catch (error) {

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, primaryDb } from "@server/db"; import { db } from "@server/db";
import { import {
roles, roles,
Client, Client,
@@ -92,10 +92,7 @@ export async function createClient(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if ( if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
req.user &&
(!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)
) {
return next( return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role") createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
); );
@@ -201,10 +198,7 @@ export async function createClient(
if (!randomExitNode) { if (!randomExitNode) {
return next( return next(
createHttpError( createHttpError(HttpCode.NOT_FOUND, `No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`)
HttpCode.NOT_FOUND,
`No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`
)
); );
} }
@@ -262,17 +256,9 @@ export async function createClient(
clientId: newClient.clientId, clientId: newClient.clientId,
dateCreated: moment().toISOString() dateCreated: moment().toISOString()
}); });
});
if (newClient) { await rebuildClientAssociationsFromClient(newClient, trx);
rebuildClientAssociationsFromClient(newClient, primaryDb).catch( });
(e) => {
logger.error(
`Failed to rebuild client associations after creating client: ${e}`
);
}
);
}
return response<CreateClientResponse>(res, { return response<CreateClientResponse>(res, {
data: newClient, data: newClient,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, primaryDb } from "@server/db"; import { db } from "@server/db";
import { import {
roles, roles,
Client, Client,
@@ -237,17 +237,9 @@ export async function createUserClient(
userId, userId,
clientId: newClient.clientId clientId: newClient.clientId
}); });
});
if (newClient) { await rebuildClientAssociationsFromClient(newClient, trx);
rebuildClientAssociationsFromClient(newClient, primaryDb).catch( });
(e) => {
logger.error(
`Failed to rebuild client associations after creating user client: ${e}`
);
}
);
}
return response<CreateClientAndOlmResponse>(res, { return response<CreateClientAndOlmResponse>(res, {
data: newClient, data: newClient,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, olms, primaryDb, Client, Olm } from "@server/db"; import { db, olms } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db"; import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -71,17 +71,14 @@ export async function deleteClient(
); );
} }
let deletedClient: Client | undefined;
let olm: Olm | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Then delete the client itself // Then delete the client itself
[deletedClient] = await trx const [deletedClient] = await trx
.delete(clients) .delete(clients)
.where(eq(clients.clientId, clientId)) .where(eq(clients.clientId, clientId))
.returning(); .returning();
[olm] = await trx const [olm] = await trx
.select() .select()
.from(olms) .from(olms)
.where(eq(olms.clientId, clientId)) .where(eq(olms.clientId, clientId))
@@ -91,28 +88,13 @@ export async function deleteClient(
if (!client.userId && client.olmId) { if (!client.userId && client.olmId) {
await trx.delete(olms).where(eq(olms.olmId, client.olmId)); await trx.delete(olms).where(eq(olms.olmId, client.olmId));
} }
});
if (deletedClient) { await rebuildClientAssociationsFromClient(deletedClient, trx);
rebuildClientAssociationsFromClient(deletedClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after deleting client ${clientId}: ${e}`
);
}
);
if (olm) { if (olm) {
sendTerminateClient( await sendTerminateClient(deletedClient.clientId, OlmErrorCodes.TERMINATED_DELETED, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
olm.olmId
).catch((e) => {
logger.error(
`Failed to send terminate message for client ${deletedClient?.clientId} after deleting client ${clientId}: ${e}`
);
});
} }
} });
return response(res, { return response(res, {
data: null, data: null,

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { db, olms, primaryDb } from "@server/db"; import { db, olms } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { z } from "zod"; import { z } from "zod";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -81,19 +81,16 @@ export async function createUserOlm(
const secretHash = await hashPassword(secret); const secretHash = await hashPassword(secret);
await db.insert(olms).values({ await db.transaction(async (trx) => {
olmId: olmId, await trx.insert(olms).values({
userId, olmId: olmId,
name, userId,
secretHash, name,
dateCreated: moment().toISOString() secretHash,
}); dateCreated: moment().toISOString()
});
calculateUserClientsForOrgs(userId, primaryDb).catch((e) => { await calculateUserClientsForOrgs(userId, trx);
console.error(
"Error calculating user clients after creating olm:",
e
);
}); });
return response<CreateOlmResponse>(res, { return response<CreateOlmResponse>(res, {

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { Client, db, Olm, primaryDb } from "@server/db"; import { Client, db } from "@server/db";
import { olms, clients, clientSitesAssociationsCache } from "@server/db"; import { olms, clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -49,7 +49,6 @@ export async function deleteUserOlm(
const { olmId } = parsedParams.data; const { olmId } = parsedParams.data;
let deletedClient: Client | undefined;
// Delete associated clients and the OLM in a transaction // Delete associated clients and the OLM in a transaction
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Find all clients associated with this OLM // Find all clients associated with this OLM
@@ -58,6 +57,7 @@ export async function deleteUserOlm(
.from(clients) .from(clients)
.where(eq(clients.olmId, olmId)); .where(eq(clients.olmId, olmId));
let deletedClient: Client | null = null;
// Delete all associated clients // Delete all associated clients
if (associatedClients.length > 0) { if (associatedClients.length > 0) {
[deletedClient] = await trx [deletedClient] = await trx
@@ -67,27 +67,22 @@ export async function deleteUserOlm(
} }
// Finally, delete the OLM itself // Finally, delete the OLM itself
await trx.delete(olms).where(eq(olms.olmId, olmId)).returning(); const [olm] = await trx
}); .delete(olms)
.where(eq(olms.olmId, olmId))
.returning();
if (deletedClient) { if (deletedClient) {
rebuildClientAssociationsFromClient(deletedClient, primaryDb).catch( await rebuildClientAssociationsFromClient(deletedClient, trx);
(e) => { if (olm) {
logger.error( await sendTerminateClient(
`Failed to rebuild client-site associations after deleting OLM ${olmId}: ${e}` deletedClient.clientId,
); OlmErrorCodes.TERMINATED_DELETED,
olm.olmId
); // the olmId needs to be provided because it cant look it up after deletion
} }
); }
sendTerminateClient( });
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
olmId
).catch((e) => {
logger.error(
`Failed to send terminate message for client ${deletedClient?.clientId} after deleting OLM ${olmId}: ${e}`
);
});
}
return response(res, { return response(res, {
data: null, data: null,

View File

@@ -22,14 +22,14 @@ import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config"; import config from "@server/lib/config";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("[handleOlmRegisterMessage] Handling register olm message"); logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context; const { message, client: c, sendToClient } = context;
const olm = c as Olm; const olm = c as Olm;
const now = Math.floor(Date.now() / 1000); const now = Math.floor(Date.now() / 1000);
if (!olm) { if (!olm) {
logger.warn("[handleOlmRegisterMessage] Olm not found"); logger.warn("Olm not found");
return; return;
} }
@@ -46,19 +46,16 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} = message.data; } = message.data;
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("[handleOlmRegisterMessage] Olm client ID not found"); logger.warn("Olm client ID not found");
sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId);
return; return;
} }
logger.debug( logger.debug("Handling fingerprint insertion for olm register...", {
"[handleOlmRegisterMessage] Handling fingerprint insertion for olm register...", olmId: olm.olmId,
{ fingerprint,
olmId: olm.olmId, postures
fingerprint, });
postures
}
);
const isUserDevice = olm.userId !== null && olm.userId !== undefined; const isUserDevice = olm.userId !== null && olm.userId !== undefined;
@@ -88,17 +85,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.limit(1); .limit(1);
if (!client) { if (!client) {
logger.warn("[handleOlmRegisterMessage] Client not found", { logger.warn("Client ID not found");
clientId: olm.clientId
});
sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId);
return; return;
} }
if (client.blocked) { if (client.blocked) {
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`, `Client ${client.clientId} is blocked. Ignoring register.`
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
return; return;
@@ -106,8 +100,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.approvalState == "pending") { if (client.approvalState == "pending") {
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. Ignoring register.`, `Client ${client.clientId} approval is pending. Ignoring register.`
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
return; return;
@@ -135,18 +128,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.limit(1); .limit(1);
if (!org) { if (!org) {
logger.warn("[handleOlmRegisterMessage] Org not found", { logger.warn("Org not found");
orgId: client.orgId
});
sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
return; return;
} }
if (orgId) { if (orgId) {
if (!olm.userId) { if (!olm.userId) {
logger.warn("[handleOlmRegisterMessage] Olm has no user ID", { logger.warn("Olm has no user ID");
orgId: client.orgId
});
sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
return; return;
} }
@@ -154,18 +143,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const { session: userSession, user } = const { session: userSession, user } =
await validateSessionToken(userToken); await validateSessionToken(userToken);
if (!userSession || !user) { if (!userSession || !user) {
logger.warn( logger.warn("Invalid user session for olm register");
"[handleOlmRegisterMessage] Invalid user session for olm register",
{ orgId: client.orgId }
);
sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId); sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
return; return;
} }
if (user.userId !== olm.userId) { if (user.userId !== olm.userId) {
logger.warn( logger.warn("User ID mismatch for olm register");
"[handleOlmRegisterMessage] User ID mismatch for olm register",
{ orgId: client.orgId }
);
sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
return; return;
} }
@@ -180,15 +163,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
sessionId // this is the user token passed in the message sessionId // this is the user token passed in the message
}); });
logger.debug("[handleOlmRegisterMessage] Policy check result", { logger.debug("Policy check result:", policyCheck);
orgId: client.orgId,
policyCheck
});
if (policyCheck?.error) { if (policyCheck?.error) {
logger.error( logger.error(
`[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`, `Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -196,8 +175,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (policyCheck.policies?.passwordAge?.compliant === false) { if (policyCheck.policies?.passwordAge?.compliant === false) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`, `Olm user ${olm.userId} has non-compliant password age for org ${orgId}`
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
@@ -208,8 +186,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
policyCheck.policies?.maxSessionLength?.compliant === false policyCheck.policies?.maxSessionLength?.compliant === false
) { ) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`, `Olm user ${olm.userId} has non-compliant session length for org ${orgId}`
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
@@ -218,8 +195,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} else if (policyCheck.policies?.requiredTwoFactor === false) { } else if (policyCheck.policies?.requiredTwoFactor === false) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`, `Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED, OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
@@ -228,8 +204,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} else if (!policyCheck.allowed) { } else if (!policyCheck.allowed) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`, `Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -251,39 +226,29 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations // Prepare an array to store site configurations
logger.debug( logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
`[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`,
{ orgId: client.orgId }
);
let jitMode = false; let jitMode = false;
if (sitesCount > 250 && build == "saas") { if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER // THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites // we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info( logger.info("Too many sites (%d), dropping into JIT mode", sitesCount);
`[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`,
{ orgId: client.orgId }
);
jitMode = true; jitMode = true;
} }
logger.debug( logger.debug(
`[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`, `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
{ orgId: client.orgId }
); );
if (!publicKey) { if (!publicKey) {
logger.warn("[handleOlmRegisterMessage] Public key not provided", { logger.warn("Public key not provided");
orgId: client.orgId
});
return; return;
} }
if (client.pubKey !== publicKey || client.archived) { if (client.pubKey !== publicKey || client.archived) {
logger.info( logger.info(
"[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...", "Public key mismatch. Updating public key and clearing session info..."
{ orgId: client.orgId }
); );
// Update the client's public key // Update the client's public key
await db await db
@@ -309,13 +274,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) { if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn( logger.warn(
`[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`, `Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
{ orgId: client.orgId }
); );
return; return;
} }
// NOTE: its important that the client here is the old client and the public key is the new key // NOTE: its important that the client here is the old client and the public key is the new key
const siteConfigurations = await buildSiteConfigurationForOlmClient( const siteConfigurations = await buildSiteConfigurationForOlmClient(
client, client,
publicKey, publicKey,

View File

@@ -5,8 +5,7 @@ import {
clients, clients,
clientSiteResources, clientSiteResources,
siteResources, siteResources,
apiKeyOrg, apiKeyOrg
primaryDb
} from "@server/db"; } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -221,12 +220,8 @@ export async function batchAddClientToSiteResources(
siteResourceId: siteResource.siteResourceId siteResourceId: siteResource.siteResourceId
}); });
} }
});
rebuildClientAssociationsFromClient(client, primaryDb).catch((e) => { await rebuildClientAssociationsFromClient(client, trx);
logger.error(
`Failed to rebuild client associations after batch adding site resources for client ${clientId}: ${e}`
);
}); });
return response(res, { return response(res, {

View File

@@ -1,13 +1,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, orgs, primaryDb } from "@server/db"; import { db, orgs } from "@server/db";
import { import { roles, userInviteRoles, userInvites, userOrgs, users } from "@server/db";
roles,
userInviteRoles,
userInvites,
userOrgs,
users
} from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -152,7 +146,9 @@ export async function acceptInvite(
.from(userInviteRoles) .from(userInviteRoles)
.where(eq(userInviteRoles.inviteId, inviteId)); .where(eq(userInviteRoles.inviteId, inviteId));
const inviteRoleIds = [...new Set(inviteRoleRows.map((r) => r.roleId))]; const inviteRoleIds = [
...new Set(inviteRoleRows.map((r) => r.roleId))
];
if (inviteRoleIds.length === 0) { if (inviteRoleIds.length === 0) {
return next( return next(
createHttpError( createHttpError(
@@ -197,19 +193,13 @@ export async function acceptInvite(
.delete(userInvites) .delete(userInvites)
.where(eq(userInvites.inviteId, inviteId)); .where(eq(userInvites.inviteId, inviteId));
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
logger.debug( logger.debug(
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}` `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
); );
}); });
calculateUserClientsForOrgs(existingUser[0].userId, primaryDb).catch(
(e) => {
logger.error(
`Failed to calculate user clients after accepting invite for user ${existingUser[0].userId}: ${e}`
);
}
);
return response<AcceptInviteResponse>(res, { return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite.orgId }, data: { accepted: true, orgId: existingInvite.orgId },
success: true, success: true,

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { clients, db, primaryDb, Client } from "@server/db"; import { clients, db } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -112,8 +112,6 @@ export async function addUserRoleLegacy(
); );
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -140,19 +138,11 @@ export async function addUserRoleLegacy(
) )
); );
orgClientsToRebuild = orgClients; for (const orgClient of orgClients) {
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after adding role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { ...existingUser, roleId }, data: { ...existingUser, roleId },
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, primaryDb } from "@server/db"; import { db } from "@server/db";
import { users } from "@server/db"; import { users } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -53,12 +53,8 @@ export async function adminRemoveUser(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx.delete(users).where(eq(users.userId, userId)); await trx.delete(users).where(eq(users.userId, userId));
});
calculateUserClientsForOrgs(userId, primaryDb).catch((e) => { await calculateUserClientsForOrgs(userId, trx);
logger.error(
`Failed to calculate user clients after removing user ${userId}: ${e}`
);
}); });
return response(res, { return response(res, {

View File

@@ -6,7 +6,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { db, orgs, primaryDb } from "@server/db"; import { db, orgs } from "@server/db";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db"; import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app"; import { generateId } from "@server/auth/sessions/app";
@@ -34,7 +34,8 @@ const bodySchema = z
roleId: z.number().int().positive().optional() roleId: z.number().int().positive().optional()
}) })
.refine( .refine(
(d) => (d.roleIds != null && d.roleIds.length > 0) || d.roleId != null, (d) =>
(d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
{ message: "roleIds or roleId is required", path: ["roleIds"] } { message: "roleIds or roleId is required", path: ["roleIds"] }
) )
.transform((data) => ({ .transform((data) => ({
@@ -99,14 +100,8 @@ export async function createOrgUser(
} }
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { const { username, email, name, type, idpId, roleIds: uniqueRoleIds } =
username, parsedBody.data;
email,
name,
type,
idpId,
roleIds: uniqueRoleIds
} = parsedBody.data;
if (build == "saas") { if (build == "saas") {
const usage = await usageService.getUsage(orgId, FeatureId.USERS); const usage = await usageService.getUsage(orgId, FeatureId.USERS);
@@ -237,7 +232,6 @@ export async function createOrgUser(
); );
} }
let userIdForClients: string | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [existingUser] = await trx const [existingUser] = await trx
.select() .select()
@@ -276,7 +270,7 @@ export async function createOrgUser(
{ {
orgId, orgId,
userId: existingUser.userId, userId: existingUser.userId,
autoProvisioned: false autoProvisioned: false,
}, },
uniqueRoleIds, uniqueRoleIds,
trx trx
@@ -298,30 +292,20 @@ export async function createOrgUser(
}) })
.returning(); .returning();
await assignUserToOrg( await assignUserToOrg(
org, org,
{ {
orgId, orgId,
userId: newUser.userId, userId: newUser.userId,
autoProvisioned: false autoProvisioned: false,
}, },
uniqueRoleIds, uniqueRoleIds,
trx trx
); );
} }
userIdForClients = userId; await calculateUserClientsForOrgs(userId, trx);
}); });
if (userIdForClients) {
calculateUserClientsForOrgs(userIdForClients, primaryDb).catch(
(e) => {
logger.error(
`Failed to calculate user clients after creating org user: ${e}`
);
}
);
}
} else { } else {
return next( return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required") createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -7,8 +7,7 @@ import {
siteResources, siteResources,
sites, sites,
UserOrg, UserOrg,
userSiteResources, userSiteResources
primaryDb
} from "@server/db"; } from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db"; import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, count, eq, exists, inArray } from "drizzle-orm"; import { and, count, eq, exists, inArray } from "drizzle-orm";
@@ -92,12 +91,25 @@ export async function removeUserOrg(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await removeUserFromOrg(org, userId, trx); await removeUserFromOrg(org, userId, trx);
});
calculateUserClientsForOrgs(userId, primaryDb).catch((e) => { // if (build === "saas") {
logger.error( // const [rootUser] = await trx
`Failed to calculate user clients after removing user ${userId} from org ${orgId}: ${e}` // .select()
); // .from(users)
// .where(eq(users.userId, userId));
//
// const [leftInOrgs] = await trx
// .select({ count: count() })
// .from(userOrgs)
// .where(eq(userOrgs.userId, userId));
//
// // if the user is not an internal user and does not belong to any org, delete the entire user
// if (rootUser?.type !== UserType.Internal && !leftInOrgs.count) {
// await trx.delete(users).where(eq(users.userId, userId));
// }
// }
await calculateUserClientsForOrgs(userId, trx);
}); });
return response(res, { return response(res, {

View File

@@ -3,6 +3,8 @@ import { sql } from "drizzle-orm";
const version = "1.18.3"; const version = "1.18.3";
await migration();
export default async function migration() { export default async function migration() {
console.log(`Running setup script ${version}...`); console.log(`Running setup script ${version}...`);
@@ -44,7 +46,7 @@ export default async function migration() {
await db.execute(sql`BEGIN`); await db.execute(sql`BEGIN`);
await db.execute(sql` await db.execute(sql`
CREATE TABLE IF NOT EXISTS "trialNotifications" ( CREATE TABLE "trialNotifications" (
"notificationId" serial PRIMARY KEY NOT NULL, "notificationId" serial PRIMARY KEY NOT NULL,
"subscriptionId" varchar(255) NOT NULL, "subscriptionId" varchar(255) NOT NULL,
"notificationType" varchar(50) NOT NULL, "notificationType" varchar(50) NOT NULL,
@@ -52,6 +54,10 @@ export default async function migration() {
); );
`); `);
await db.execute(sql`
ALTER TABLE "trialNotifications" ADD CONSTRAINT "trialNotifications_subscriptionId_subscriptions_subscriptionId_fk" FOREIGN KEY ("subscriptionId") REFERENCES "public"."subscriptions"("subscriptionId") ON DELETE cascade ON UPDATE no action;
`);
await db.execute(sql`COMMIT`); await db.execute(sql`COMMIT`);
console.log("Migrated database"); console.log("Migrated database");
} catch (e) { } catch (e) {

View File

@@ -16,7 +16,7 @@ export default async function migration() {
db.transaction(() => { db.transaction(() => {
db.prepare( db.prepare(
` `
CREATE TABLE IF NOT EXISTS 'trialNotifications' ( CREATE TABLE 'trialNotifications' (
'notificationId' integer PRIMARY KEY AUTOINCREMENT NOT NULL, 'notificationId' integer PRIMARY KEY AUTOINCREMENT NOT NULL,
'subscriptionId' text NOT NULL, 'subscriptionId' text NOT NULL,
'notificationType' text NOT NULL, 'notificationType' text NOT NULL,

View File

@@ -55,9 +55,7 @@ export default async function ProxyResourcesPage(
pagination = responseData.pagination; pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
const siteIdParam = parsePositiveInt( const siteIdParam = parsePositiveInt(searchParams.get("siteId") ?? undefined);
searchParams.get("siteId") ?? undefined
);
let initialFilterSite: { let initialFilterSite: {
siteId: number; siteId: number;
@@ -124,7 +122,6 @@ export default async function ProxyResourcesPage(
domainId: resource.domainId || undefined, domainId: resource.domainId || undefined,
fullDomain: resource.fullDomain ?? null, fullDomain: resource.fullDomain ?? null,
ssl: resource.ssl, ssl: resource.ssl,
wildcard: resource.wildcard,
targets: resource.targets?.map((target) => ({ targets: resource.targets?.map((target) => ({
targetId: target.targetId, targetId: target.targetId,
ip: target.ip, ip: target.ip,

View File

@@ -96,7 +96,6 @@ export type ResourceRow = {
targets?: TargetHealth[]; targets?: TargetHealth[];
health?: "healthy" | "degraded" | "unhealthy" | "unknown"; health?: "healthy" | "degraded" | "unhealthy" | "unknown";
sites: ResourceSiteRow[]; sites: ResourceSiteRow[];
wildcard?: boolean;
}; };
function StatusIcon({ function StatusIcon({
@@ -571,14 +570,10 @@ export default function ProxyResourcesTable({
/> />
) : null} ) : null}
<div className=""> <div className="">
{!resourceRow.wildcard ? ( <CopyToClipboard
<CopyToClipboard text={resourceRow.domain}
text={resourceRow.domain} isLink={true}
isLink={true} />
/>
) : (
<span>{resourceRow.domain}</span>
)}
</div> </div>
</div> </div>
); );