Compare commits

..

9 Commits
1.18.3 ... dev

Author SHA1 Message Date
Owen Schwartz
10f95896aa Merge pull request #3030 from fosrl/dev
1.18.3-s.2 fix
2026-05-07 20:08:05 -07:00
Owen
5b8994d143 Cange to use primaryDb 2026-05-07 20:07:06 -07:00
Owen
c46ef2fe9c Fix ts type issue 2026-05-07 20:03:48 -07:00
Owen Schwartz
4cd025dd91 Merge pull request #3029 from fosrl/dev
1.18.3-s.2
2026-05-07 17:44:35 -07:00
Owen
ce04ea9720 Fix not including today
Fixes #3028
2026-05-07 16:15:13 -07:00
Owen
a3ce382725 Pick up other domains in the sans field 2026-05-07 15:49:12 -07:00
Owen
4eb49e3e60 Make the rebuild long running function background 2026-05-07 15:40:34 -07:00
Owen
2a9481023a Dont show link when wildcard 2026-05-07 15:15:03 -07:00
Owen
8ed01372b8 Add org to logs 2026-05-07 15:14:44 -07:00
24 changed files with 478 additions and 295 deletions

View File

@@ -87,7 +87,7 @@ function createDb() {
export const db = createDb(); export const db = createDb();
export default db; export default db;
export const primaryDb = db.$primary; export const primaryDb = db.$primary as typeof db; // is this typeof a problem - techincally they are different types
export type Transaction = Parameters< export type Transaction = Parameters<
Parameters<(typeof db)["transaction"]>[0] Parameters<(typeof db)["transaction"]>[0]
>[0]; >[0];

View File

@@ -25,9 +25,9 @@ import { tierMatrix } from "./billing/tierMatrix";
export async function calculateUserClientsForOrgs( export async function calculateUserClientsForOrgs(
userId: string, userId: string,
trx?: Transaction trx: Transaction | typeof db = db
): Promise<void> { ): Promise<void> {
const execute = async (transaction: Transaction) => { const execute = async (transaction: Transaction | typeof db) => {
const orgCache = new Map<string, typeof orgs.$inferSelect | null>(); const orgCache = new Map<string, typeof orgs.$inferSelect | null>();
const adminRoleCache = new Map< const adminRoleCache = new Map<
string, string,
@@ -437,7 +437,7 @@ export async function calculateUserClientsForOrgs(
async function cleanupOrphanedClients( async function cleanupOrphanedClients(
userId: string, userId: string,
trx: Transaction, trx: Transaction | typeof db,
userOrgIds: string[] = [] userOrgIds: string[] = []
): Promise<void> { ): Promise<void> {
// Find all OLM clients for this user that should be deleted // Find all OLM clients for this user that should be deleted

View File

@@ -124,7 +124,7 @@ export function computeBuckets(
let totalDowntime = 0; let totalDowntime = 0;
for (let d = 0; d < days; d++) { for (let d = 0; d < days; d++) {
const dayStartSec = todayMidnightSec - (days - d) * 86400; const dayStartSec = todayMidnightSec - (days - 1 - d) * 86400;
const dayEndSec = dayStartSec + 86400; const dayEndSec = dayStartSec + 86400;
const dayEvents = events.filter( const dayEvents = events.filter(

View File

@@ -485,6 +485,133 @@ async function syncAcmeCertsFromHttp(endpoint: string): Promise<void> {
} }
} }
async function storeCertForDomain(
domain: string,
certPem: string,
keyPem: string,
validatedX509: crypto.X509Certificate
): Promise<void> {
const wildcard = domain.startsWith("*.");
const existing = await db
.select()
.from(certificates)
.where(eq(certificates.domain, domain))
.limit(1);
let oldCertPem: string | null = null;
let oldKeyPem: string | null = null;
if (existing.length > 0 && existing[0].certFile) {
try {
const storedCertPem = decrypt(
existing[0].certFile,
config.getRawConfig().server.secret!
);
const wildcardUnchanged = existing[0].wildcard === wildcard;
if (storedCertPem === certPem && wildcardUnchanged) {
return;
}
oldCertPem = storedCertPem;
if (existing[0].keyFile) {
try {
oldKeyPem = decrypt(
existing[0].keyFile,
config.getRawConfig().server.secret!
);
} catch (keyErr) {
logger.debug(
`acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}`
);
}
}
} catch (err) {
logger.debug(
`acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}`
);
}
}
let expiresAt: number | null = null;
try {
expiresAt = Math.floor(
new Date(validatedX509.validTo).getTime() / 1000
);
} catch (err) {
logger.debug(
`acmeCertSync: could not parse cert expiry for ${domain}: ${err}`
);
}
const encryptedCert = encrypt(
certPem,
config.getRawConfig().server.secret!
);
const encryptedKey = encrypt(keyPem, config.getRawConfig().server.secret!);
const now = Math.floor(Date.now() / 1000);
const domainId = await findDomainId(domain);
if (domainId) {
logger.debug(
`acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"`
);
} else {
logger.debug(
`acmeCertSync: no matching domain record found for cert domain "${domain}"`
);
}
if (existing.length > 0) {
logger.debug(
`acmeCertSync: updating existing certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db
.update(certificates)
.set({
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
updatedAt: now,
wildcard,
...(domainId !== null && { domainId })
})
.where(eq(certificates.domain, domain));
logger.debug(
`acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(
domain,
domainId,
oldCertPem,
oldKeyPem
);
} else {
logger.debug(
`acmeCertSync: inserting new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db.insert(certificates).values({
domain,
domainId,
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
createdAt: now,
updatedAt: now,
wildcard
});
logger.debug(
`acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(domain, domainId, null, null);
}
}
function findAcmeJsonFiles(dirPath: string): string[] { function findAcmeJsonFiles(dirPath: string): string[] {
const results: string[] = []; const results: string[] = [];
let entries: fs.Dirent[]; let entries: fs.Dirent[];
@@ -575,18 +702,16 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
} }
for (const cert of allCerts) { for (const cert of allCerts) {
const domain = cert?.domain?.main; const mainDomain = cert?.domain?.main;
if (!domain || typeof domain !== "string") { if (!mainDomain || typeof mainDomain !== "string") {
logger.debug(`acmeCertSync: skipping cert with missing domain`); logger.debug(`acmeCertSync: skipping cert with missing domain`);
continue; continue;
} }
const { wildcard } = detectWildcard(domain, cert.domain?.sans);
if (!cert.certificate || !cert.key) { if (!cert.certificate || !cert.key) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - empty certificate or key field` `acmeCertSync: skipping cert for ${mainDomain} - empty certificate or key field`
); );
continue; continue;
} }
@@ -598,14 +723,14 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
keyPem = Buffer.from(cert.key, "base64").toString("utf8"); keyPem = Buffer.from(cert.key, "base64").toString("utf8");
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - failed to base64-decode cert/key: ${err}` `acmeCertSync: skipping cert for ${mainDomain} - failed to base64-decode cert/key: ${err}`
); );
continue; continue;
} }
if (!certPem.trim() || !keyPem.trim()) { if (!certPem.trim() || !keyPem.trim()) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - blank PEM after base64 decode` `acmeCertSync: skipping cert for ${mainDomain} - blank PEM after base64 decode`
); );
continue; continue;
} }
@@ -616,7 +741,7 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
const firstCertPemForValidation = extractFirstCert(certPem); const firstCertPemForValidation = extractFirstCert(certPem);
if (!firstCertPemForValidation) { if (!firstCertPemForValidation) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - no PEM certificate block found` `acmeCertSync: skipping cert for ${mainDomain} - no PEM certificate block found`
); );
continue; continue;
} }
@@ -628,7 +753,7 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
); );
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - invalid X.509 certificate: ${err}` `acmeCertSync: skipping cert for ${mainDomain} - invalid X.509 certificate: ${err}`
); );
continue; continue;
} }
@@ -638,139 +763,40 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
crypto.createPrivateKey(keyPem); crypto.createPrivateKey(keyPem);
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - invalid private key: ${err}` `acmeCertSync: skipping cert for ${mainDomain} - invalid private key: ${err}`
); );
continue; continue;
} }
// Check if cert already exists in DB // Collect all domains covered by this cert: main + every SAN.
const existing = await db // Each domain gets its own row in the certificates table so that
.select() // lookups by any hostname on the cert succeed independently.
.from(certificates) const allDomains = new Set<string>([mainDomain]);
.where(and(eq(certificates.domain, domain))) if (Array.isArray(cert.domain?.sans)) {
.limit(1); for (const san of cert.domain.sans) {
if (typeof san === "string" && san.trim()) {
let oldCertPem: string | null = null; allDomains.add(san.trim());
let oldKeyPem: string | null = null;
if (existing.length > 0 && existing[0].certFile) {
try {
const storedCertPem = decrypt(
existing[0].certFile,
config.getRawConfig().server.secret!
);
const wildcardUnchanged = existing[0].wildcard === wildcard;
if (storedCertPem === certPem && wildcardUnchanged) {
// logger.debug(
// `acmeCertSync: cert for ${domain} is unchanged, skipping`
// );
continue;
} }
// Cert has changed; capture old values so we can send a correct }
// update message to the newt after the DB write. }
oldCertPem = storedCertPem;
if (existing[0].keyFile) {
try {
oldKeyPem = decrypt(
existing[0].keyFile,
config.getRawConfig().server.secret!
);
} catch (keyErr) {
logger.debug( logger.debug(
`acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}` `acmeCertSync: cert for ${mainDomain} covers ${allDomains.size} domain(s): ${[...allDomains].join(", ")}`
); );
}
}
} catch (err) {
// Decryption failure means we should proceed with the update
logger.debug(
`acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}`
);
}
}
// Parse cert expiry from the validated X.509 certificate for (const domain of allDomains) {
let expiresAt: number | null = null;
try { try {
expiresAt = Math.floor( await storeCertForDomain(
new Date(validatedX509.validTo).getTime() / 1000 domain,
);
} catch (err) {
logger.debug(
`acmeCertSync: could not parse cert expiry for ${domain}: ${err}`
);
}
const encryptedCert = encrypt(
certPem, certPem,
config.getRawConfig().server.secret!
);
const encryptedKey = encrypt(
keyPem, keyPem,
config.getRawConfig().server.secret! validatedX509
); );
const now = Math.floor(Date.now() / 1000); } catch (err) {
logger.error(
const domainId = await findDomainId(domain); `acmeCertSync: error storing cert for domain "${domain}": ${err}`
if (domainId) {
logger.debug(
`acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"`
);
} else {
logger.debug(
`acmeCertSync: no matching domain record found for cert domain "${domain}"`
); );
} }
if (existing.length > 0) {
logger.debug(
`acmeCertSync: updating existing certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db
.update(certificates)
.set({
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
updatedAt: now,
wildcard,
...(domainId !== null && { domainId })
})
.where(eq(certificates.domain, domain));
logger.debug(
`acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(
domain,
domainId,
oldCertPem,
oldKeyPem
);
} else {
logger.debug(
`acmeCertSync: inserting new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db.insert(certificates).values({
domain,
domainId,
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
createdAt: now,
updatedAt: now,
wildcard
});
logger.debug(
`acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
// For a brand-new cert, push to any SSL resources that were waiting for it
await pushCertUpdateToAffectedNewts(domain, domainId, null, null);
} }
} }
} }

View File

@@ -14,7 +14,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { clients, db } from "@server/db"; import { clients, db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -122,8 +122,12 @@ export async function addUserRole(
); );
} }
let newUserRole: { userId: string; orgId: string; roleId: number } | null = let newUserRole: {
null; userId: string;
orgId: string;
roleId: number;
} | null = null;
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const inserted = await trx const inserted = await trx
.insert(userOrgRoles) .insert(userOrgRoles)
@@ -149,11 +153,19 @@ export async function addUserRole(
) )
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after adding role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: newUserRole ?? { userId, orgId: role.orgId, roleId }, data: newUserRole ?? { userId, orgId: role.orgId, roleId },
success: true, success: true,

View File

@@ -14,7 +14,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { db } from "@server/db"; import { db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles, clients } from "@server/db"; import { userOrgRoles, userOrgs, roles, clients } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -129,6 +129,7 @@ export async function removeUserRole(
} }
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -150,11 +151,19 @@ export async function removeUserRole(
) )
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after removing role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { userId, orgId: role.orgId, roleId }, data: { userId, orgId: role.orgId, roleId },
success: true, success: true,

View File

@@ -13,7 +13,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { clients, db } from "@server/db"; import { clients, db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -115,6 +115,7 @@ export async function setUserOrgRoles(
); );
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -142,11 +143,19 @@ export async function setUserOrgRoles(
and(eq(clients.userId, userId), eq(clients.orgId, orgId)) and(eq(clients.userId, userId), eq(clients.orgId, orgId))
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after setting roles: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { userId, orgId, roleIds: uniqueRoleIds }, data: { userId, orgId, roleIds: uniqueRoleIds },
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, orgs, userOrgs, users } from "@server/db"; import { db, orgs, userOrgs, users, primaryDb } from "@server/db";
import { eq, and, inArray, not } from "drizzle-orm"; import { eq, and, inArray, not } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -218,13 +218,18 @@ export async function deleteMyAccount(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx.delete(users).where(eq(users.userId, userId)); await trx.delete(users).where(eq(users.userId, userId));
await calculateUserClientsForOrgs(userId, trx);
// loop through the other orgs and decrement the count // loop through the other orgs and decrement the count
for (const userOrg of otherOrgsTheUserWasIn) { for (const userOrg of otherOrgsTheUserWasIn) {
await usageService.add(userOrg.orgId, FeatureId.USERS, -1, trx); await usageService.add(userOrg.orgId, FeatureId.USERS, -1, trx);
} }
}); });
calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
logger.error(
`Failed to calculate user clients after deleting account for user ${userId}: ${e}`
);
});
try { try {
await invalidateSession(session.sessionId); await invalidateSession(session.sessionId);
} catch (error) { } catch (error) {

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, primaryDb } from "@server/db";
import { import {
roles, roles,
Client, Client,
@@ -92,7 +92,10 @@ export async function createClient(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) { if (
req.user &&
(!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)
) {
return next( return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role") createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
); );
@@ -198,7 +201,10 @@ export async function createClient(
if (!randomExitNode) { if (!randomExitNode) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, `No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`) createHttpError(
HttpCode.NOT_FOUND,
`No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`
)
); );
} }
@@ -256,10 +262,18 @@ export async function createClient(
clientId: newClient.clientId, clientId: newClient.clientId,
dateCreated: moment().toISOString() dateCreated: moment().toISOString()
}); });
await rebuildClientAssociationsFromClient(newClient, trx);
}); });
if (newClient) {
rebuildClientAssociationsFromClient(newClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after creating client: ${e}`
);
}
);
}
return response<CreateClientResponse>(res, { return response<CreateClientResponse>(res, {
data: newClient, data: newClient,
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, primaryDb } from "@server/db";
import { import {
roles, roles,
Client, Client,
@@ -237,10 +237,18 @@ export async function createUserClient(
userId, userId,
clientId: newClient.clientId clientId: newClient.clientId
}); });
await rebuildClientAssociationsFromClient(newClient, trx);
}); });
if (newClient) {
rebuildClientAssociationsFromClient(newClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after creating user client: ${e}`
);
}
);
}
return response<CreateClientAndOlmResponse>(res, { return response<CreateClientAndOlmResponse>(res, {
data: newClient, data: newClient,
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, olms } from "@server/db"; import { db, olms, primaryDb, Client, Olm } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db"; import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -71,14 +71,17 @@ export async function deleteClient(
); );
} }
let deletedClient: Client | undefined;
let olm: Olm | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Then delete the client itself // Then delete the client itself
const [deletedClient] = await trx [deletedClient] = await trx
.delete(clients) .delete(clients)
.where(eq(clients.clientId, clientId)) .where(eq(clients.clientId, clientId))
.returning(); .returning();
const [olm] = await trx [olm] = await trx
.select() .select()
.from(olms) .from(olms)
.where(eq(olms.clientId, clientId)) .where(eq(olms.clientId, clientId))
@@ -88,14 +91,29 @@ export async function deleteClient(
if (!client.userId && client.olmId) { if (!client.userId && client.olmId) {
await trx.delete(olms).where(eq(olms.olmId, client.olmId)); await trx.delete(olms).where(eq(olms.olmId, client.olmId));
} }
await rebuildClientAssociationsFromClient(deletedClient, trx);
if (olm) {
await sendTerminateClient(deletedClient.clientId, OlmErrorCodes.TERMINATED_DELETED, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion
}
}); });
if (deletedClient) {
rebuildClientAssociationsFromClient(deletedClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after deleting client ${clientId}: ${e}`
);
}
);
if (olm) {
sendTerminateClient(
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
olm.olmId
).catch((e) => {
logger.error(
`Failed to send terminate message for client ${deletedClient?.clientId} after deleting client ${clientId}: ${e}`
);
});
}
}
return response(res, { return response(res, {
data: null, data: null,
success: true, success: true,

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { db, olms } from "@server/db"; import { db, olms, primaryDb } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { z } from "zod"; import { z } from "zod";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -81,8 +81,7 @@ export async function createUserOlm(
const secretHash = await hashPassword(secret); const secretHash = await hashPassword(secret);
await db.transaction(async (trx) => { await db.insert(olms).values({
await trx.insert(olms).values({
olmId: olmId, olmId: olmId,
userId, userId,
name, name,
@@ -90,7 +89,11 @@ export async function createUserOlm(
dateCreated: moment().toISOString() dateCreated: moment().toISOString()
}); });
await calculateUserClientsForOrgs(userId, trx); calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
console.error(
"Error calculating user clients after creating olm:",
e
);
}); });
return response<CreateOlmResponse>(res, { return response<CreateOlmResponse>(res, {

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { Client, db } from "@server/db"; import { Client, db, Olm, primaryDb } from "@server/db";
import { olms, clients, clientSitesAssociationsCache } from "@server/db"; import { olms, clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -49,6 +49,7 @@ export async function deleteUserOlm(
const { olmId } = parsedParams.data; const { olmId } = parsedParams.data;
let deletedClient: Client | undefined;
// Delete associated clients and the OLM in a transaction // Delete associated clients and the OLM in a transaction
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Find all clients associated with this OLM // Find all clients associated with this OLM
@@ -57,7 +58,6 @@ export async function deleteUserOlm(
.from(clients) .from(clients)
.where(eq(clients.olmId, olmId)); .where(eq(clients.olmId, olmId));
let deletedClient: Client | null = null;
// Delete all associated clients // Delete all associated clients
if (associatedClients.length > 0) { if (associatedClients.length > 0) {
[deletedClient] = await trx [deletedClient] = await trx
@@ -67,22 +67,27 @@ export async function deleteUserOlm(
} }
// Finally, delete the OLM itself // Finally, delete the OLM itself
const [olm] = await trx await trx.delete(olms).where(eq(olms.olmId, olmId)).returning();
.delete(olms) });
.where(eq(olms.olmId, olmId))
.returning();
if (deletedClient) { if (deletedClient) {
await rebuildClientAssociationsFromClient(deletedClient, trx); rebuildClientAssociationsFromClient(deletedClient, primaryDb).catch(
if (olm) { (e) => {
await sendTerminateClient( logger.error(
`Failed to rebuild client-site associations after deleting OLM ${olmId}: ${e}`
);
}
);
sendTerminateClient(
deletedClient.clientId, deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED, OlmErrorCodes.TERMINATED_DELETED,
olm.olmId olmId
); // the olmId needs to be provided because it cant look it up after deletion ).catch((e) => {
} logger.error(
} `Failed to send terminate message for client ${deletedClient?.clientId} after deleting OLM ${olmId}: ${e}`
);
}); });
}
return response(res, { return response(res, {
data: null, data: null,

View File

@@ -22,14 +22,14 @@ import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config"; import config from "@server/lib/config";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("[handleOlmRegisterMessage] Handling register olm message");
const { message, client: c, sendToClient } = context; const { message, client: c, sendToClient } = context;
const olm = c as Olm; const olm = c as Olm;
const now = Math.floor(Date.now() / 1000); const now = Math.floor(Date.now() / 1000);
if (!olm) { if (!olm) {
logger.warn("Olm not found"); logger.warn("[handleOlmRegisterMessage] Olm not found");
return; return;
} }
@@ -46,16 +46,19 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} = message.data; } = message.data;
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm client ID not found"); logger.warn("[handleOlmRegisterMessage] Olm client ID not found");
sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId);
return; return;
} }
logger.debug("Handling fingerprint insertion for olm register...", { logger.debug(
"[handleOlmRegisterMessage] Handling fingerprint insertion for olm register...",
{
olmId: olm.olmId, olmId: olm.olmId,
fingerprint, fingerprint,
postures postures
}); }
);
const isUserDevice = olm.userId !== null && olm.userId !== undefined; const isUserDevice = olm.userId !== null && olm.userId !== undefined;
@@ -85,14 +88,17 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.limit(1); .limit(1);
if (!client) { if (!client) {
logger.warn("Client ID not found"); logger.warn("[handleOlmRegisterMessage] Client not found", {
clientId: olm.clientId
});
sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId);
return; return;
} }
if (client.blocked) { if (client.blocked) {
logger.debug( logger.debug(
`Client ${client.clientId} is blocked. Ignoring register.` `[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
return; return;
@@ -100,7 +106,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.approvalState == "pending") { if (client.approvalState == "pending") {
logger.debug( logger.debug(
`Client ${client.clientId} approval is pending. Ignoring register.` `[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. Ignoring register.`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
return; return;
@@ -128,14 +135,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.limit(1); .limit(1);
if (!org) { if (!org) {
logger.warn("Org not found"); logger.warn("[handleOlmRegisterMessage] Org not found", {
orgId: client.orgId
});
sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
return; return;
} }
if (orgId) { if (orgId) {
if (!olm.userId) { if (!olm.userId) {
logger.warn("Olm has no user ID"); logger.warn("[handleOlmRegisterMessage] Olm has no user ID", {
orgId: client.orgId
});
sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
return; return;
} }
@@ -143,12 +154,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const { session: userSession, user } = const { session: userSession, user } =
await validateSessionToken(userToken); await validateSessionToken(userToken);
if (!userSession || !user) { if (!userSession || !user) {
logger.warn("Invalid user session for olm register"); logger.warn(
"[handleOlmRegisterMessage] Invalid user session for olm register",
{ orgId: client.orgId }
);
sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId); sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
return; return;
} }
if (user.userId !== olm.userId) { if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm register"); logger.warn(
"[handleOlmRegisterMessage] User ID mismatch for olm register",
{ orgId: client.orgId }
);
sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
return; return;
} }
@@ -163,11 +180,15 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
sessionId // this is the user token passed in the message sessionId // this is the user token passed in the message
}); });
logger.debug("Policy check result:", policyCheck); logger.debug("[handleOlmRegisterMessage] Policy check result", {
orgId: client.orgId,
policyCheck
});
if (policyCheck?.error) { if (policyCheck?.error) {
logger.error( logger.error(
`Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}` `[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -175,7 +196,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (policyCheck.policies?.passwordAge?.compliant === false) { if (policyCheck.policies?.passwordAge?.compliant === false) {
logger.warn( logger.warn(
`Olm user ${olm.userId} has non-compliant password age for org ${orgId}` `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`,
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
@@ -186,7 +208,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
policyCheck.policies?.maxSessionLength?.compliant === false policyCheck.policies?.maxSessionLength?.compliant === false
) { ) {
logger.warn( logger.warn(
`Olm user ${olm.userId} has non-compliant session length for org ${orgId}` `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`,
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
@@ -195,7 +218,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} else if (policyCheck.policies?.requiredTwoFactor === false) { } else if (policyCheck.policies?.requiredTwoFactor === false) {
logger.warn( logger.warn(
`Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}` `[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`,
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED, OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
@@ -204,7 +228,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} else if (!policyCheck.allowed) { } else if (!policyCheck.allowed) {
logger.warn( logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}` `[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -226,29 +251,39 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations // Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`); logger.debug(
`[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`,
{ orgId: client.orgId }
);
let jitMode = false; let jitMode = false;
if (sitesCount > 250 && build == "saas") { if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER // THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites // we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount); logger.info(
`[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`,
{ orgId: client.orgId }
);
jitMode = true; jitMode = true;
} }
logger.debug( logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` `[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`,
{ orgId: client.orgId }
); );
if (!publicKey) { if (!publicKey) {
logger.warn("Public key not provided"); logger.warn("[handleOlmRegisterMessage] Public key not provided", {
orgId: client.orgId
});
return; return;
} }
if (client.pubKey !== publicKey || client.archived) { if (client.pubKey !== publicKey || client.archived) {
logger.info( logger.info(
"Public key mismatch. Updating public key and clearing session info..." "[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...",
{ orgId: client.orgId }
); );
// Update the client's public key // Update the client's public key
await db await db
@@ -274,7 +309,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) { if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn( logger.warn(
`Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?` `[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`,
{ orgId: client.orgId }
); );
return; return;
} }

View File

@@ -5,7 +5,8 @@ import {
clients, clients,
clientSiteResources, clientSiteResources,
siteResources, siteResources,
apiKeyOrg apiKeyOrg,
primaryDb
} from "@server/db"; } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -220,8 +221,12 @@ export async function batchAddClientToSiteResources(
siteResourceId: siteResource.siteResourceId siteResourceId: siteResource.siteResourceId
}); });
} }
});
await rebuildClientAssociationsFromClient(client, trx); rebuildClientAssociationsFromClient(client, primaryDb).catch((e) => {
logger.error(
`Failed to rebuild client associations after batch adding site resources for client ${clientId}: ${e}`
);
}); });
return response(res, { return response(res, {

View File

@@ -10,7 +10,8 @@ import {
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
userSiteResources userSiteResources,
primaryDb
} from "@server/db"; } from "@server/db";
import { getUniqueSiteResourceName } from "@server/db/names"; import { getUniqueSiteResourceName } from "@server/db/names";
import { import {
@@ -519,12 +520,10 @@ export async function createSiteResource(
// own transaction so it always executes on the primary — avoiding any // own transaction so it always executes on the primary — avoiding any
// replica-lag issues while still allowing the HTTP response to return // replica-lag issues while still allowing the HTTP response to return
// early. // early.
db.transaction(async (trx) => { rebuildClientAssociationsFromSiteResource(
await rebuildClientAssociationsFromSiteResource(
newSiteResource!, newSiteResource!,
trx primaryDb
); ).catch((err) => {
}).catch((err) => {
logger.error( logger.error(
`Error rebuilding client associations for site resource ${newSiteResource!.siteResourceId}:`, `Error rebuilding client associations for site resource ${newSiteResource!.siteResourceId}:`,
err err

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, newts, sites } from "@server/db"; import { db, newts, primaryDb, sites } from "@server/db";
import { siteResources } from "@server/db"; import { siteResources } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -73,12 +73,10 @@ export async function deleteSiteResource(
// own transaction so it always executes on the primary — avoiding any // own transaction so it always executes on the primary — avoiding any
// replica-lag issues while still allowing the HTTP response to return // replica-lag issues while still allowing the HTTP response to return
// early. // early.
db.transaction(async (trx) => { rebuildClientAssociationsFromSiteResource(
await rebuildClientAssociationsFromSiteResource(
removedSiteResource, removedSiteResource,
trx primaryDb
); ).catch((err) => {
}).catch((err) => {
logger.error( logger.error(
`Error rebuilding client associations for site resource ${removedSiteResource!.siteResourceId}:`, `Error rebuilding client associations for site resource ${removedSiteResource!.siteResourceId}:`,
err err

View File

@@ -1,7 +1,13 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, orgs } from "@server/db"; import { db, orgs, primaryDb } from "@server/db";
import { roles, userInviteRoles, userInvites, userOrgs, users } from "@server/db"; import {
roles,
userInviteRoles,
userInvites,
userOrgs,
users
} from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -146,9 +152,7 @@ export async function acceptInvite(
.from(userInviteRoles) .from(userInviteRoles)
.where(eq(userInviteRoles.inviteId, inviteId)); .where(eq(userInviteRoles.inviteId, inviteId));
const inviteRoleIds = [ const inviteRoleIds = [...new Set(inviteRoleRows.map((r) => r.roleId))];
...new Set(inviteRoleRows.map((r) => r.roleId))
];
if (inviteRoleIds.length === 0) { if (inviteRoleIds.length === 0) {
return next( return next(
createHttpError( createHttpError(
@@ -193,13 +197,19 @@ export async function acceptInvite(
.delete(userInvites) .delete(userInvites)
.where(eq(userInvites.inviteId, inviteId)); .where(eq(userInvites.inviteId, inviteId));
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
logger.debug( logger.debug(
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}` `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
); );
}); });
calculateUserClientsForOrgs(existingUser[0].userId, primaryDb).catch(
(e) => {
logger.error(
`Failed to calculate user clients after accepting invite for user ${existingUser[0].userId}: ${e}`
);
}
);
return response<AcceptInviteResponse>(res, { return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite.orgId }, data: { accepted: true, orgId: existingInvite.orgId },
success: true, success: true,

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { clients, db } from "@server/db"; import { clients, db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -112,6 +112,8 @@ export async function addUserRoleLegacy(
); );
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -138,11 +140,19 @@ export async function addUserRoleLegacy(
) )
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after adding role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { ...existingUser, roleId }, data: { ...existingUser, roleId },
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, primaryDb } from "@server/db";
import { users } from "@server/db"; import { users } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -53,8 +53,12 @@ export async function adminRemoveUser(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx.delete(users).where(eq(users.userId, userId)); await trx.delete(users).where(eq(users.userId, userId));
});
await calculateUserClientsForOrgs(userId, trx); calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
logger.error(
`Failed to calculate user clients after removing user ${userId}: ${e}`
);
}); });
return response(res, { return response(res, {

View File

@@ -6,7 +6,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { db, orgs } from "@server/db"; import { db, orgs, primaryDb } from "@server/db";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db"; import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app"; import { generateId } from "@server/auth/sessions/app";
@@ -34,8 +34,7 @@ const bodySchema = z
roleId: z.number().int().positive().optional() roleId: z.number().int().positive().optional()
}) })
.refine( .refine(
(d) => (d) => (d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
(d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
{ message: "roleIds or roleId is required", path: ["roleIds"] } { message: "roleIds or roleId is required", path: ["roleIds"] }
) )
.transform((data) => ({ .transform((data) => ({
@@ -100,8 +99,14 @@ export async function createOrgUser(
} }
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { username, email, name, type, idpId, roleIds: uniqueRoleIds } = const {
parsedBody.data; username,
email,
name,
type,
idpId,
roleIds: uniqueRoleIds
} = parsedBody.data;
if (build == "saas") { if (build == "saas") {
const usage = await usageService.getUsage(orgId, FeatureId.USERS); const usage = await usageService.getUsage(orgId, FeatureId.USERS);
@@ -232,6 +237,7 @@ export async function createOrgUser(
); );
} }
let userIdForClients: string | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [existingUser] = await trx const [existingUser] = await trx
.select() .select()
@@ -270,7 +276,7 @@ export async function createOrgUser(
{ {
orgId, orgId,
userId: existingUser.userId, userId: existingUser.userId,
autoProvisioned: false, autoProvisioned: false
}, },
uniqueRoleIds, uniqueRoleIds,
trx trx
@@ -297,15 +303,25 @@ export async function createOrgUser(
{ {
orgId, orgId,
userId: newUser.userId, userId: newUser.userId,
autoProvisioned: false, autoProvisioned: false
}, },
uniqueRoleIds, uniqueRoleIds,
trx trx
); );
} }
await calculateUserClientsForOrgs(userId, trx); userIdForClients = userId;
}); });
if (userIdForClients) {
calculateUserClientsForOrgs(userIdForClients, primaryDb).catch(
(e) => {
logger.error(
`Failed to calculate user clients after creating org user: ${e}`
);
}
);
}
} else { } else {
return next( return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required") createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -7,7 +7,8 @@ import {
siteResources, siteResources,
sites, sites,
UserOrg, UserOrg,
userSiteResources userSiteResources,
primaryDb
} from "@server/db"; } from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db"; import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, count, eq, exists, inArray } from "drizzle-orm"; import { and, count, eq, exists, inArray } from "drizzle-orm";
@@ -91,25 +92,12 @@ export async function removeUserOrg(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await removeUserFromOrg(org, userId, trx); await removeUserFromOrg(org, userId, trx);
});
// if (build === "saas") { calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
// const [rootUser] = await trx logger.error(
// .select() `Failed to calculate user clients after removing user ${userId} from org ${orgId}: ${e}`
// .from(users) );
// .where(eq(users.userId, userId));
//
// const [leftInOrgs] = await trx
// .select({ count: count() })
// .from(userOrgs)
// .where(eq(userOrgs.userId, userId));
//
// // if the user is not an internal user and does not belong to any org, delete the entire user
// if (rootUser?.type !== UserType.Internal && !leftInOrgs.count) {
// await trx.delete(users).where(eq(users.userId, userId));
// }
// }
await calculateUserClientsForOrgs(userId, trx);
}); });
return response(res, { return response(res, {

View File

@@ -55,7 +55,9 @@ export default async function ProxyResourcesPage(
pagination = responseData.pagination; pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
const siteIdParam = parsePositiveInt(searchParams.get("siteId") ?? undefined); const siteIdParam = parsePositiveInt(
searchParams.get("siteId") ?? undefined
);
let initialFilterSite: { let initialFilterSite: {
siteId: number; siteId: number;
@@ -122,6 +124,7 @@ export default async function ProxyResourcesPage(
domainId: resource.domainId || undefined, domainId: resource.domainId || undefined,
fullDomain: resource.fullDomain ?? null, fullDomain: resource.fullDomain ?? null,
ssl: resource.ssl, ssl: resource.ssl,
wildcard: resource.wildcard,
targets: resource.targets?.map((target) => ({ targets: resource.targets?.map((target) => ({
targetId: target.targetId, targetId: target.targetId,
ip: target.ip, ip: target.ip,

View File

@@ -96,6 +96,7 @@ export type ResourceRow = {
targets?: TargetHealth[]; targets?: TargetHealth[];
health?: "healthy" | "degraded" | "unhealthy" | "unknown"; health?: "healthy" | "degraded" | "unhealthy" | "unknown";
sites: ResourceSiteRow[]; sites: ResourceSiteRow[];
wildcard?: boolean;
}; };
function StatusIcon({ function StatusIcon({
@@ -570,10 +571,14 @@ export default function ProxyResourcesTable({
/> />
) : null} ) : null}
<div className=""> <div className="">
{!resourceRow.wildcard ? (
<CopyToClipboard <CopyToClipboard
text={resourceRow.domain} text={resourceRow.domain}
isLink={true} isLink={true}
/> />
) : (
<span>{resourceRow.domain}</span>
)}
</div> </div>
</div> </div>
); );