Merge branch 'dev' into feat/device-approvals

This commit is contained in:
Fred KISSIE
2026-01-14 23:08:12 +01:00
78 changed files with 2815 additions and 421 deletions

View File

@@ -78,6 +78,10 @@ export enum ActionsEnum {
updateSiteResource = "updateSiteResource",
createClient = "createClient",
deleteClient = "deleteClient",
archiveClient = "archiveClient",
unarchiveClient = "unarchiveClient",
blockClient = "blockClient",
unblockClient = "unblockClient",
updateClient = "updateClient",
listClients = "listClients",
getClient = "getClient",

View File

@@ -592,7 +592,8 @@ export const idp = pgTable("idp", {
type: varchar("type").notNull(),
defaultRoleMapping: varchar("defaultRoleMapping"),
defaultOrgMapping: varchar("defaultOrgMapping"),
autoProvision: boolean("autoProvision").notNull().default(false)
autoProvision: boolean("autoProvision").notNull().default(false),
tags: text("tags")
});
export const idpOidcConfig = pgTable("idpOidcConfig", {
@@ -690,6 +691,8 @@ export const clients = pgTable("clients", {
// endpoint: varchar("endpoint"),
lastHolePunch: integer("lastHolePunch"),
maxConnections: integer("maxConnections"),
archived: boolean("archived").notNull().default(false),
blocked: boolean("blocked").notNull().default(false),
approvalState: varchar("approvalState")
.$type<"pending" | "approved" | "denied">()
.default("approved")
@@ -730,7 +733,8 @@ export const olms = pgTable("olms", {
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
}),
archived: boolean("archived").notNull().default(false)
});
export const olmSessions = pgTable("clientSession", {

View File

@@ -1,4 +1,4 @@
import { db, loginPage, LoginPage, loginPageOrg, Org, orgs } from "@server/db";
import { db, loginPage, LoginPage, loginPageOrg, Org, orgs, roles } from "@server/db";
import {
Resource,
ResourcePassword,
@@ -108,9 +108,17 @@ export async function getUserSessionWithUser(
*/
export async function getUserOrgRole(userId: string, orgId: string) {
const userOrgRole = await db
.select()
.select({
userId: userOrgs.userId,
orgId: userOrgs.orgId,
roleId: userOrgs.roleId,
isOwner: userOrgs.isOwner,
autoProvisioned: userOrgs.autoProvisioned,
roleName: roles.name
})
.from(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
.leftJoin(roles, eq(userOrgs.roleId, roles.roleId))
.limit(1);
return userOrgRole.length > 0 ? userOrgRole[0] : null;

View File

@@ -385,7 +385,9 @@ export const clients = sqliteTable("clients", {
type: text("type").notNull(), // "olm"
online: integer("online", { mode: "boolean" }).notNull().default(false),
// endpoint: text("endpoint"),
lastHolePunch: integer("lastHolePunch")
lastHolePunch: integer("lastHolePunch"),
archived: integer("archived", { mode: "boolean" }).notNull().default(false),
blocked: integer("blocked", { mode: "boolean" }).notNull().default(false)
});
export const clientSitesAssociationsCache = sqliteTable(
@@ -425,7 +427,8 @@ export const olms = sqliteTable("olms", {
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
}),
archived: integer("archived", { mode: "boolean" }).notNull().default(false)
});
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
@@ -779,7 +782,8 @@ export const idp = sqliteTable("idp", {
mode: "boolean"
})
.notNull()
.default(false)
.default(false),
tags: text("tags")
});
// Identity Provider OAuth Configuration

View File

@@ -290,8 +290,8 @@ export const ClientResourceSchema = z
alias: z
.string()
.regex(
/^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/,
"Alias must be a fully qualified domain name (e.g., example.com)"
/^(?:[a-zA-Z0-9*?](?:[a-zA-Z0-9*?-]{0,61}[a-zA-Z0-9*?])?\.)+[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/,
"Alias must be a fully qualified domain name with optional wildcards (e.g., example.com, *.example.com, host-0?.example.internal)"
)
.optional(),
roles: z

View File

@@ -13,3 +13,4 @@ export * from "./verifyApiKeyIsRoot";
export * from "./verifyApiKeyApiKeyAccess";
export * from "./verifyApiKeyClientAccess";
export * from "./verifyApiKeySiteResourceAccess";
export * from "./verifyApiKeyIdpAccess";

View File

@@ -0,0 +1,88 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { idp, idpOrg, apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyApiKeyIdpAccess(
req: Request,
res: Response,
next: NextFunction
) {
try {
const apiKey = req.apiKey;
const idpId = req.params.idpId || req.body.idpId || req.query.idpId;
const orgId = req.params.orgId;
if (!apiKey) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Key not authenticated")
);
}
if (!orgId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID")
);
}
if (!idpId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid IDP ID")
);
}
if (apiKey.isRoot) {
// Root keys can access any IDP in any org
return next();
}
const [idpRes] = await db
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.where(and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId)))
.limit(1);
if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`IdP with ID ${idpId} not found for organization ${orgId}`
)
);
}
if (!req.apiKeyOrg) {
const apiKeyOrgRes = await db
.select()
.from(apiKeyOrg)
.where(
and(
eq(apiKeyOrg.apiKeyId, apiKey.apiKeyId),
eq(apiKeyOrg.orgId, idpRes.idpOrg.orgId)
)
);
req.apiKeyOrg = apiKeyOrgRes[0];
}
if (!req.apiKeyOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to this organization"
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying IDP access"
)
);
}
}

View File

@@ -139,6 +139,10 @@ export class PrivateConfig {
process.env.USE_PANGOLIN_DNS =
this.rawPrivateConfig.flags.use_pangolin_dns.toString();
}
if (this.rawPrivateConfig.flags.use_org_only_idp) {
process.env.USE_ORG_ONLY_IDP =
this.rawPrivateConfig.flags.use_org_only_idp.toString();
}
}
public getRawPrivateConfig() {

View File

@@ -288,7 +288,7 @@ export function selectBestExitNode(
const validNodes = pingResults.filter((n) => !n.error && n.weight > 0);
if (validNodes.length === 0) {
logger.error("No valid exit nodes available");
logger.debug("No valid exit nodes available");
return null;
}

View File

@@ -24,7 +24,9 @@ export class LockManager {
*/
async acquireLock(
lockKey: string,
ttlMs: number = 30000
ttlMs: number = 30000,
maxRetries: number = 3,
retryDelayMs: number = 100
): Promise<boolean> {
if (!redis || !redis.status || redis.status !== "ready") {
return true;
@@ -35,49 +37,67 @@ export class LockManager {
}:${Date.now()}`;
const redisKey = `lock:${lockKey}`;
try {
// Use SET with NX (only set if not exists) and PX (expire in milliseconds)
// This is atomic and handles both setting and expiration
const result = await redis.set(
redisKey,
lockValue,
"PX",
ttlMs,
"NX"
);
if (result === "OK") {
logger.debug(
`Lock acquired: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
for (let attempt = 0; attempt < maxRetries; attempt++) {
try {
// Use SET with NX (only set if not exists) and PX (expire in milliseconds)
// This is atomic and handles both setting and expiration
const result = await redis.set(
redisKey,
lockValue,
"PX",
ttlMs,
"NX"
);
return true;
}
// Check if the existing lock is from this worker (reentrant behavior)
const existingValue = await redis.get(redisKey);
if (
existingValue &&
existingValue.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
)
) {
// Extend the lock TTL since it's the same worker
await redis.pexpire(redisKey, ttlMs);
logger.debug(
`Lock extended: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
return true;
}
if (result === "OK") {
logger.debug(
`Lock acquired: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
return true;
}
return false;
} catch (error) {
logger.error(`Failed to acquire lock ${lockKey}:`, error);
return false;
// Check if the existing lock is from this worker (reentrant behavior)
const existingValue = await redis.get(redisKey);
if (
existingValue &&
existingValue.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
)
) {
// Extend the lock TTL since it's the same worker
await redis.pexpire(redisKey, ttlMs);
logger.debug(
`Lock extended: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
return true;
}
// If this isn't our last attempt, wait before retrying with exponential backoff
if (attempt < maxRetries - 1) {
const delay = retryDelayMs * Math.pow(2, attempt);
logger.debug(
`Lock ${lockKey} not available, retrying in ${delay}ms (attempt ${attempt + 1}/${maxRetries})`
);
await new Promise((resolve) => setTimeout(resolve, delay));
}
} catch (error) {
logger.error(`Failed to acquire lock ${lockKey} (attempt ${attempt + 1}/${maxRetries}):`, error);
// On error, still retry if we have attempts left
if (attempt < maxRetries - 1) {
const delay = retryDelayMs * Math.pow(2, attempt);
await new Promise((resolve) => setTimeout(resolve, delay));
}
}
}
logger.debug(
`Failed to acquire lock ${lockKey} after ${maxRetries} attempts`
);
return false;
}
/**

View File

@@ -83,7 +83,8 @@ export const privateConfigSchema = z.object({
flags: z
.object({
enable_redis: z.boolean().optional().default(false),
use_pangolin_dns: z.boolean().optional().default(false)
use_pangolin_dns: z.boolean().optional().default(false),
use_org_only_idp: z.boolean().optional().default(false)
})
.optional()
.prefault({}),

View File

@@ -456,11 +456,11 @@ export async function getTraefikConfig(
// );
} else if (resource.maintenanceModeType === "automatic") {
showMaintenancePage = !hasHealthyServers;
if (showMaintenancePage) {
logger.warn(
`Resource ${resource.name} (${fullDomain}) has no healthy servers - showing maintenance page (AUTOMATIC mode)`
);
}
// if (showMaintenancePage) {
// logger.warn(
// `Resource ${resource.name} (${fullDomain}) has no healthy servers - showing maintenance page (AUTOMATIC mode)`
// );
// }
}
}

View File

@@ -27,7 +27,18 @@ export async function verifyValidSubscription(
return next();
}
const tier = await getOrgTierData(req.params.orgId);
const orgId = req.params.orgId || req.body.orgId || req.query.orgId || req.userOrgId;
if (!orgId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization ID is required to verify subscription"
)
);
}
const tier = await getOrgTierData(orgId);
if (!tier.active) {
return next(

View File

@@ -455,18 +455,18 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyClientAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret
);
authenticated.post(
"/re-key/:siteId/regenerate-site-secret",
verifySiteAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription,
verifySiteAccess,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret
);

View File

@@ -18,7 +18,8 @@ import * as logs from "#private/routers/auditLogs";
import {
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess
verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess
} from "@server/middlewares";
import {
verifyValidSubscription,
@@ -31,6 +32,8 @@ import {
authenticated as a
} from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares";
import config from "#private/lib/config";
import { build } from "@server/build";
export const unauthenticated = ua;
export const authenticated = a;
@@ -88,3 +91,49 @@ authenticated.get(
logActionAudit(ActionsEnum.exportLogs),
logs.exportAccessAuditLogs
);
authenticated.put(
"/org/:orgId/idp/oidc",
verifyValidLicense,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp),
orgIdp.createOrgOidcIdp
);
authenticated.post(
"/org/:orgId/idp/:idpId/oidc",
verifyValidLicense,
verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess,
verifyApiKeyHasAction(ActionsEnum.updateIdp),
logActionAudit(ActionsEnum.updateIdp),
orgIdp.updateOrgOidcIdp
);
authenticated.delete(
"/org/:orgId/idp/:idpId",
verifyValidLicense,
verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess,
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
logActionAudit(ActionsEnum.deleteIdp),
orgIdp.deleteOrgIdp
);
authenticated.get(
"/org/:orgId/idp/:idpId",
verifyValidLicense,
verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess,
verifyApiKeyHasAction(ActionsEnum.getIdp),
orgIdp.getOrgIdp
);
authenticated.get(
"/org/:orgId/idp",
verifyValidLicense,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.listIdps),
orgIdp.listOrgIdps
);

View File

@@ -28,6 +28,7 @@ import { eq, InferInsertModel } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import config from "@server/private/lib/config";
const paramsSchema = z.strictObject({
orgId: z.string()
@@ -94,8 +95,10 @@ export async function upsertLoginPageBranding(
typeof loginPageBranding
>;
if (build !== "saas") {
// org branding settings are only considered in the saas build
if (
build !== "saas" &&
!config.getRawPrivateConfig().flags.use_org_only_idp
) {
const { orgTitle, orgSubtitle, ...rest } = updateData;
updateData = rest;
}

View File

@@ -43,25 +43,27 @@ const bodySchema = z.strictObject({
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
roleMapping: z.string().optional(),
tags: z.string().optional()
});
// registry.registerPath({
// method: "put",
// path: "/idp/oidc",
// description: "Create an OIDC IdP.",
// tags: [OpenAPITags.Idp],
// request: {
// body: {
// content: {
// "application/json": {
// schema: bodySchema
// }
// }
// }
// },
// responses: {}
// });
registry.registerPath({
method: "put",
path: "/org/{orgId}/idp/oidc",
description: "Create an OIDC IdP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
request: {
params: paramsSchema,
body: {
content: {
"application/json": {
schema: bodySchema
}
}
}
},
responses: {}
});
export async function createOrgOidcIdp(
req: Request,
@@ -103,7 +105,8 @@ export async function createOrgOidcIdp(
name,
autoProvision,
variant,
roleMapping
roleMapping,
tags
} = parsedBody.data;
if (build === "saas") {
@@ -131,7 +134,8 @@ export async function createOrgOidcIdp(
.values({
name,
autoProvision,
type: "oidc"
type: "oidc",
tags
})
.returning();

View File

@@ -32,9 +32,9 @@ const paramsSchema = z
registry.registerPath({
method: "delete",
path: "/idp/{idpId}",
description: "Delete IDP.",
tags: [OpenAPITags.Idp],
path: "/org/{orgId}/idp/{idpId}",
description: "Delete IDP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
request: {
params: paramsSchema
},

View File

@@ -48,16 +48,16 @@ async function query(idpId: number, orgId: string) {
return res;
}
// registry.registerPath({
// method: "get",
// path: "/idp/{idpId}",
// description: "Get an IDP by its IDP ID.",
// tags: [OpenAPITags.Idp],
// request: {
// params: paramsSchema
// },
// responses: {}
// });
registry.registerPath({
method: "get",
path: "/org/:orgId/idp/:idpId",
description: "Get an IDP by its IDP ID for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
request: {
params: paramsSchema
},
responses: {}
});
export async function getOrgIdp(
req: Request,

View File

@@ -50,7 +50,8 @@ async function query(orgId: string, limit: number, offset: number) {
orgId: idpOrg.orgId,
name: idp.name,
type: idp.type,
variant: idpOidcConfig.variant
variant: idpOidcConfig.variant,
tags: idp.tags
})
.from(idpOrg)
.where(eq(idpOrg.orgId, orgId))
@@ -62,16 +63,17 @@ async function query(orgId: string, limit: number, offset: number) {
return res;
}
// registry.registerPath({
// method: "get",
// path: "/idp",
// description: "List all IDP in the system.",
// tags: [OpenAPITags.Idp],
// request: {
// query: querySchema
// },
// responses: {}
// });
registry.registerPath({
method: "get",
path: "/org/{orgId}/idp",
description: "List all IDP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
request: {
query: querySchema,
params: paramsSchema
},
responses: {}
});
export async function listOrgIdps(
req: Request,

View File

@@ -46,30 +46,31 @@ const bodySchema = z.strictObject({
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
roleMapping: z.string().optional(),
tags: z.string().optional()
});
export type UpdateOrgIdpResponse = {
idpId: number;
};
// registry.registerPath({
// method: "post",
// path: "/idp/{idpId}/oidc",
// description: "Update an OIDC IdP.",
// tags: [OpenAPITags.Idp],
// request: {
// params: paramsSchema,
// body: {
// content: {
// "application/json": {
// schema: bodySchema
// }
// }
// }
// },
// responses: {}
// });
registry.registerPath({
method: "post",
path: "/org/{orgId}/idp/{idpId}/oidc",
description: "Update an OIDC IdP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
request: {
params: paramsSchema,
body: {
content: {
"application/json": {
schema: bodySchema
}
}
}
},
responses: {}
});
export async function updateOrgOidcIdp(
req: Request,
@@ -109,7 +110,8 @@ export async function updateOrgOidcIdp(
namePath,
name,
autoProvision,
roleMapping
roleMapping,
tags
} = parsedBody.data;
if (build === "saas") {
@@ -167,7 +169,8 @@ export async function updateOrgOidcIdp(
await db.transaction(async (trx) => {
const idpData = {
name,
autoProvision
autoProvision,
tags
};
// only update if at least one key is not undefined

View File

@@ -16,4 +16,4 @@ export * from "./checkResourceSession";
export * from "./securityKey";
export * from "./startDeviceWebAuth";
export * from "./verifyDeviceWebAuth";
export * from "./pollDeviceWebAuth";
export * from "./pollDeviceWebAuth";

View File

@@ -942,7 +942,7 @@ async function isUserAllowedToAccessResource(
username: user.username,
email: user.email,
name: user.name,
role: user.role
role: userOrgRole.roleName
};
}
@@ -956,7 +956,7 @@ async function isUserAllowedToAccessResource(
username: user.username,
email: user.email,
name: user.name,
role: user.role
role: userOrgRole.roleName
};
}

View File

@@ -0,0 +1,105 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { sendTerminateClient } from "./terminate";
const archiveClientSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "post",
path: "/client/{clientId}/archive",
description: "Archive a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: archiveClientSchema
},
responses: {}
});
export async function archiveClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = archiveClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Check if client exists
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (client.archived) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Client with ID ${clientId} is already archived`
)
);
}
await db.transaction(async (trx) => {
// Archive the client
await trx
.update(clients)
.set({ archived: true })
.where(eq(clients.clientId, clientId));
// Rebuild associations to clean up related data
await rebuildClientAssociationsFromClient(client, trx);
// Send terminate signal if there's an associated OLM
if (client.olmId) {
await sendTerminateClient(client.clientId, client.olmId);
}
});
return response(res, {
data: null,
success: true,
error: false,
message: "Client archived successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to archive client"
)
);
}
}

View File

@@ -0,0 +1,101 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { sendTerminateClient } from "./terminate";
const blockClientSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "post",
path: "/client/{clientId}/block",
description: "Block a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: blockClientSchema
},
responses: {}
});
export async function blockClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = blockClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Check if client exists
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (client.blocked) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Client with ID ${clientId} is already blocked`
)
);
}
await db.transaction(async (trx) => {
// Block the client
await trx
.update(clients)
.set({ blocked: true })
.where(eq(clients.clientId, clientId));
// Send terminate signal if there's an associated OLM and it's connected
if (client.olmId && client.online) {
await sendTerminateClient(client.clientId, client.olmId);
}
});
return response(res, {
data: null,
success: true,
error: false,
message: "Client blocked successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to block client"
)
);
}
}

View File

@@ -60,11 +60,12 @@ export async function deleteClient(
);
}
// Only allow deletion of machine clients (clients without userId)
if (client.userId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Cannot delete a user client with this endpoint`
`Cannot delete a user client. User clients must be archived instead.`
)
);
}

View File

@@ -1,6 +1,10 @@
export * from "./pickClientDefaults";
export * from "./createClient";
export * from "./deleteClient";
export * from "./archiveClient";
export * from "./unarchiveClient";
export * from "./blockClient";
export * from "./unblockClient";
export * from "./listClients";
export * from "./updateClient";
export * from "./getClient";

View File

@@ -137,7 +137,10 @@ function queryClients(
userEmail: users.email,
niceId: clients.niceId,
agent: olms.agent,
approvalState: clients.approvalState
approvalState: clients.approvalState,
olmArchived: olms.archived,
archived: clients.archived,
blocked: clients.blocked
})
.from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))

View File

@@ -0,0 +1,93 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const unarchiveClientSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "post",
path: "/client/{clientId}/unarchive",
description: "Unarchive a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: unarchiveClientSchema
},
responses: {}
});
export async function unarchiveClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = unarchiveClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Check if client exists
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (!client.archived) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Client with ID ${clientId} is not archived`
)
);
}
// Unarchive the client
await db
.update(clients)
.set({ archived: false })
.where(eq(clients.clientId, clientId));
return response(res, {
data: null,
success: true,
error: false,
message: "Client unarchived successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to unarchive client"
)
);
}
}

View File

@@ -0,0 +1,93 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const unblockClientSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "post",
path: "/client/{clientId}/unblock",
description: "Unblock a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: unblockClientSchema
},
responses: {}
});
export async function unblockClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = unblockClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Check if client exists
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (!client.blocked) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Client with ID ${clientId} is not blocked`
)
);
}
// Unblock the client
await db
.update(clients)
.set({ blocked: false })
.where(eq(clients.clientId, clientId));
return response(res, {
data: null,
success: true,
error: false,
message: "Client unblocked successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to unblock client"
)
);
}
}

View File

@@ -174,6 +174,38 @@ authenticated.delete(
client.deleteClient
);
authenticated.post(
"/client/:clientId/archive",
verifyClientAccess,
verifyUserHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient),
client.archiveClient
);
authenticated.post(
"/client/:clientId/unarchive",
verifyClientAccess,
verifyUserHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient
);
authenticated.post(
"/client/:clientId/block",
verifyClientAccess,
verifyUserHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient),
client.blockClient
);
authenticated.post(
"/client/:clientId/unblock",
verifyClientAccess,
verifyUserHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient),
client.unblockClient
);
authenticated.post(
"/client/:clientId",
verifyClientAccess, // this will check if the user has access to the client
@@ -816,11 +848,18 @@ authenticated.put("/user/:userId/olm", verifyIsLoggedInUser, olm.createUserOlm);
authenticated.get("/user/:userId/olms", verifyIsLoggedInUser, olm.listUserOlms);
authenticated.delete(
"/user/:userId/olm/:olmId",
authenticated.post(
"/user/:userId/olm/:olmId/archive",
verifyIsLoggedInUser,
verifyOlmAccess,
olm.deleteUserOlm
olm.archiveUserOlm
);
authenticated.post(
"/user/:userId/olm/:olmId/unarchive",
verifyIsLoggedInUser,
verifyOlmAccess,
olm.unarchiveUserOlm
);
authenticated.get(

View File

@@ -24,7 +24,8 @@ const bodySchema = z.strictObject({
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional()
autoProvision: z.boolean().optional(),
tags: z.string().optional()
});
export type CreateIdpResponse = {
@@ -75,7 +76,8 @@ export async function createOidcIdp(
emailPath,
namePath,
name,
autoProvision
autoProvision,
tags
} = parsedBody.data;
const key = config.getRawConfig().server.secret!;
@@ -90,7 +92,8 @@ export async function createOidcIdp(
.values({
name,
autoProvision,
type: "oidc"
type: "oidc",
tags
})
.returning();

View File

@@ -33,7 +33,8 @@ async function query(limit: number, offset: number) {
type: idp.type,
variant: idpOidcConfig.variant,
orgCount: sql<number>`count(${idpOrg.orgId})`,
autoProvision: idp.autoProvision
autoProvision: idp.autoProvision,
tags: idp.tags
})
.from(idp)
.leftJoin(idpOrg, sql`${idp.idpId} = ${idpOrg.idpId}`)

View File

@@ -30,7 +30,8 @@ const bodySchema = z.strictObject({
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
defaultRoleMapping: z.string().optional(),
defaultOrgMapping: z.string().optional()
defaultOrgMapping: z.string().optional(),
tags: z.string().optional()
});
export type UpdateIdpResponse = {
@@ -94,7 +95,8 @@ export async function updateOidcIdp(
name,
autoProvision,
defaultRoleMapping,
defaultOrgMapping
defaultOrgMapping,
tags
} = parsedBody.data;
// Check if IDP exists and is of type OIDC
@@ -127,7 +129,8 @@ export async function updateOidcIdp(
name,
autoProvision,
defaultRoleMapping,
defaultOrgMapping
defaultOrgMapping,
tags
};
// only update if at least one key is not undefined

View File

@@ -759,9 +759,10 @@ authenticated.post(
);
authenticated.get(
"/idp",
verifyApiKeyIsRoot,
verifyApiKeyHasAction(ActionsEnum.listIdps),
"/idp", // no guards on this because anyone can list idps for login purposes
// we do the same for the external api
// verifyApiKeyIsRoot,
// verifyApiKeyHasAction(ActionsEnum.listIdps),
idp.listIdps
);
@@ -850,6 +851,38 @@ authenticated.delete(
client.deleteClient
);
authenticated.post(
"/client/:clientId/archive",
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient),
client.archiveClient
);
authenticated.post(
"/client/:clientId/unarchive",
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient
);
authenticated.post(
"/client/:clientId/block",
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient),
client.blockClient
);
authenticated.post(
"/client/:clientId/unblock",
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient),
client.unblockClient
);
authenticated.post(
"/client/:clientId",
verifyApiKeyClientAccess,

View File

@@ -0,0 +1,81 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { olms, clients } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { sendTerminateClient } from "../client/terminate";
const paramsSchema = z
.object({
userId: z.string(),
olmId: z.string()
})
.strict();
export async function archiveUserOlm(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId } = parsedParams.data;
// Archive the OLM and disconnect associated clients in a transaction
await db.transaction(async (trx) => {
// Find all clients associated with this OLM
const associatedClients = await trx
.select()
.from(clients)
.where(eq(clients.olmId, olmId));
// Disconnect clients from the OLM (set olmId to null)
for (const client of associatedClients) {
await trx
.update(clients)
.set({ olmId: null })
.where(eq(clients.clientId, client.clientId));
await rebuildClientAssociationsFromClient(client, trx);
await sendTerminateClient(client.clientId, olmId);
}
// Archive the OLM (set archived to true)
await trx
.update(olms)
.set({ archived: true })
.where(eq(olms.olmId, olmId));
});
return response(res, {
data: null,
success: true,
error: false,
message: "Device archived successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to archive device"
)
);
}
}

View File

@@ -1,6 +1,6 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { olms } from "@server/db";
import { olms, clients } from "@server/db";
import { eq, and } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -17,6 +17,10 @@ const paramsSchema = z
})
.strict();
const querySchema = z.object({
orgId: z.string().optional()
});
// registry.registerPath({
// method: "get",
// path: "/user/{userId}/olm/{olmId}",
@@ -44,15 +48,56 @@ export async function getUserOlm(
);
}
const parsedQuery = querySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { olmId, userId } = parsedParams.data;
const { orgId } = parsedQuery.data;
const [olm] = await db
.select()
.from(olms)
.where(and(eq(olms.userId, userId), eq(olms.olmId, olmId)));
if (!olm) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Olm not found"
)
);
}
// If orgId is provided and olm has a clientId, fetch the client to check blocked status
let blocked: boolean | undefined;
if (orgId && olm.clientId) {
const [client] = await db
.select({ blocked: clients.blocked })
.from(clients)
.where(
and(
eq(clients.clientId, olm.clientId),
eq(clients.orgId, orgId)
)
)
.limit(1);
blocked = client?.blocked ?? false;
}
const responseData = blocked !== undefined
? { ...olm, blocked }
: olm;
return response(res, {
data: olm,
data: responseData,
success: true,
error: false,
message: "Successfully retrieved olm",

View File

@@ -1,7 +1,7 @@
import { db } from "@server/db";
import { disconnectClient } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm } from "@server/db";
import { clients, olms, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
@@ -108,29 +108,17 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
return;
}
if (olm.userId) {
// we need to check a user token to make sure its still valid
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
logger.warn("Invalid user session for olm ping");
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
}
if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm ping");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
try {
// get the client
const [client] = await db
.select()
.from(clients)
.where(
and(
eq(clients.olmId, olm.olmId),
eq(clients.userId, olm.userId)
)
)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
if (!client) {
@@ -138,38 +126,62 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
return;
}
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(userToken))
);
const policyCheck = await checkOrgAccessPolicy({
orgId: client.orgId,
userId: olm.userId,
sessionId // this is the user token passed in the message
});
if (!policyCheck.allowed) {
logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
);
if (client.blocked) {
// NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them
logger.debug(`Blocked client ${client.clientId} attempted olm ping`);
return;
}
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
if (olm.userId) {
// we need to check a user token to make sure its still valid
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
logger.warn("Invalid user session for olm ping");
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
}
if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm ping");
return;
}
if (user.userId !== client.userId) {
logger.warn("Client user ID mismatch for olm ping");
return;
}
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(userToken))
);
const policyCheck = await checkOrgAccessPolicy({
orgId: client.orgId,
userId: olm.userId,
sessionId // this is the user token passed in the message
});
if (!policyCheck.allowed) {
logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
);
return;
}
}
try {
// Update the client's last ping timestamp
await db
.update(clients)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true
online: true,
archived: false
})
.where(eq(clients.clientId, olm.clientId));
if (olm.archived) {
await db
.update(olms)
.set({ archived: false })
.where(eq(olms.olmId, olm.olmId));
}
} catch (error) {
logger.error("Error handling ping message", { error });
}

View File

@@ -55,6 +55,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
if (client.blocked) {
logger.debug(`Client ${client.clientId} is blocked. Ignoring register.`);
return;
}
const [org] = await db
.select()
.from(orgs)
@@ -112,18 +117,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (
(olmVersion && olm.version !== olmVersion) ||
(olmAgent && olm.agent !== olmAgent)
(olmAgent && olm.agent !== olmAgent) ||
olm.archived
) {
await db
.update(olms)
.set({
version: olmVersion,
agent: olmAgent
agent: olmAgent,
archived: false
})
.where(eq(olms.olmId, olm.olmId));
}
if (client.pubKey !== publicKey) {
if (client.pubKey !== publicKey || client.archived) {
logger.info(
"Public key mismatch. Updating public key and clearing session info..."
);
@@ -131,7 +138,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db
.update(clients)
.set({
pubKey: publicKey
pubKey: publicKey,
archived: false,
})
.where(eq(clients.clientId, client.clientId));

View File

@@ -3,9 +3,9 @@ export * from "./getOlmToken";
export * from "./createUserOlm";
export * from "./handleOlmRelayMessage";
export * from "./handleOlmPingMessage";
export * from "./deleteUserOlm";
export * from "./archiveUserOlm";
export * from "./unarchiveUserOlm";
export * from "./listUserOlms";
export * from "./deleteUserOlm";
export * from "./getUserOlm";
export * from "./handleOlmServerPeerAddMessage";
export * from "./handleOlmUnRelayMessage";

View File

@@ -51,6 +51,7 @@ export type ListUserOlmsResponse = {
name: string | null;
clientId: number | null;
userId: string | null;
archived: boolean;
}>;
pagination: {
total: number;
@@ -89,7 +90,7 @@ export async function listUserOlms(
const { userId } = parsedParams.data;
// Get total count
// Get total count (including archived OLMs)
const [totalCountResult] = await db
.select({ count: count() })
.from(olms)
@@ -97,7 +98,7 @@ export async function listUserOlms(
const total = totalCountResult?.count || 0;
// Get OLMs for the current user
// Get OLMs for the current user (including archived OLMs)
const userOlms = await db
.select({
olmId: olms.olmId,
@@ -105,7 +106,8 @@ export async function listUserOlms(
version: olms.version,
name: olms.name,
clientId: olms.clientId,
userId: olms.userId
userId: olms.userId,
archived: olms.archived
})
.from(olms)
.where(eq(olms.userId, userId))

View File

@@ -0,0 +1,84 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { olms } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
const paramsSchema = z
.object({
userId: z.string(),
olmId: z.string()
})
.strict();
export async function unarchiveUserOlm(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId } = parsedParams.data;
// Check if OLM exists and is archived
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId))
.limit(1);
if (!olm) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`OLM with ID ${olmId} not found`
)
);
}
if (!olm.archived) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`OLM with ID ${olmId} is not archived`
)
);
}
// Unarchive the OLM (set archived to false)
await db
.update(olms)
.set({ archived: false })
.where(eq(olms.olmId, olmId));
return response(res, {
data: null,
success: true,
error: false,
message: "Device unarchived successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to unarchive device"
)
);
}
}