Merge pull request #2121 from Fredkiss3/feat/device-approvals

feat: device approvals
This commit is contained in:
Milo Schwartz
2026-01-15 21:33:31 -08:00
committed by GitHub
39 changed files with 1549 additions and 286 deletions

View File

@@ -129,7 +129,9 @@ export enum ActionsEnum {
getBlueprint = "getBlueprint",
applyBlueprint = "applyBlueprint",
viewLogs = "viewLogs",
exportLogs = "exportLogs"
exportLogs = "exportLogs",
listApprovals = "listApprovals",
updateApprovals = "updateApprovals"
}
export async function checkUserActionPermission(

View File

@@ -10,7 +10,15 @@ import {
index
} from "drizzle-orm/pg-core";
import { InferSelectModel } from "drizzle-orm";
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";
import {
domains,
orgs,
targets,
users,
exitNodes,
sessions,
clients
} from "./schema";
export const certificates = pgTable("certificates", {
certId: serial("certId").primaryKey(),
@@ -289,6 +297,33 @@ export const accessAuditLog = pgTable(
]
);
export const approvals = pgTable("approvals", {
approvalId: serial("approvalId").primaryKey(),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}), // clients reference user devices (in this case)
userId: varchar("userId")
.references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
.notNull(),
decision: varchar("decision")
.$type<"approved" | "denied" | "pending">()
.default("pending")
.notNull(),
type: varchar("type")
.$type<"user_device" /*| 'proxy' // for later */>()
.notNull()
});
export type Approval = InferSelectModel<typeof approvals>;
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
export type Certificate = InferSelectModel<typeof certificates>;

View File

@@ -365,7 +365,8 @@ export const roles = pgTable("roles", {
.notNull(),
isAdmin: boolean("isAdmin"),
name: varchar("name").notNull(),
description: varchar("description")
description: varchar("description"),
requireDeviceApproval: boolean("requireDeviceApproval").default(false)
});
export const roleActions = pgTable("roleActions", {
@@ -691,7 +692,10 @@ export const clients = pgTable("clients", {
lastHolePunch: integer("lastHolePunch"),
maxConnections: integer("maxConnections"),
archived: boolean("archived").notNull().default(false),
blocked: boolean("blocked").notNull().default(false)
blocked: boolean("blocked").notNull().default(false),
approvalState: varchar("approvalState").$type<
"pending" | "approved" | "denied"
>()
});
export const clientSitesAssociationsCache = pgTable(

View File

@@ -6,7 +6,7 @@ import {
sqliteTable,
text
} from "drizzle-orm/sqlite-core";
import { domains, exitNodes, orgs, sessions, users } from "./schema";
import { clients, domains, exitNodes, orgs, sessions, users } from "./schema";
export const certificates = sqliteTable("certificates", {
certId: integer("certId").primaryKey({ autoIncrement: true }),
@@ -289,6 +289,31 @@ export const accessAuditLog = sqliteTable(
]
);
export const approvals = sqliteTable("approvals", {
approvalId: integer("approvalId").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}), // olms reference user devices clients
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
decision: text("decision")
.$type<"approved" | "denied" | "pending">()
.default("pending")
.notNull(),
type: text("type")
.$type<"user_device" /*| 'proxy' // for later */>()
.notNull()
});
export type Approval = InferSelectModel<typeof approvals>;
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
export type Certificate = InferSelectModel<typeof certificates>;

View File

@@ -387,7 +387,10 @@ export const clients = sqliteTable("clients", {
// endpoint: text("endpoint"),
lastHolePunch: integer("lastHolePunch"),
archived: integer("archived", { mode: "boolean" }).notNull().default(false),
blocked: integer("blocked", { mode: "boolean" }).notNull().default(false)
blocked: integer("blocked", { mode: "boolean" }).notNull().default(false),
approvalState: text("approvalState").$type<
"pending" | "approved" | "denied"
>()
});
export const clientSitesAssociationsCache = sqliteTable(
@@ -604,7 +607,10 @@ export const roles = sqliteTable("roles", {
.notNull(),
isAdmin: integer("isAdmin", { mode: "boolean" }),
name: text("name").notNull(),
description: text("description")
description: text("description"),
requireDeviceApproval: integer("requireDeviceApproval", {
mode: "boolean"
}).default(false)
});
export const roleActions = sqliteTable("roleActions", {

View File

@@ -1,21 +1,24 @@
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { build } from "@server/build";
import {
approvals,
clients,
db,
olms,
orgs,
roleClients,
roles,
Transaction,
userClients,
userOrgs,
Transaction
userOrgs
} from "@server/db";
import { eq, and, notInArray } from "drizzle-orm";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import logger from "@server/logger";
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { getUniqueClientName } from "@server/db/names";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import { isLicensedOrSubscribed } from "@server/lib/isLicencedOrSubscribed";
import logger from "@server/logger";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { and, eq, notInArray, type InferInsertModel } from "drizzle-orm";
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
export async function calculateUserClientsForOrgs(
userId: string,
@@ -38,13 +41,15 @@ export async function calculateUserClientsForOrgs(
const allUserOrgs = await transaction
.select()
.from(userOrgs)
.innerJoin(roles, eq(roles.roleId, userOrgs.roleId))
.where(eq(userOrgs.userId, userId));
const userOrgIds = allUserOrgs.map((uo) => uo.orgId);
const userOrgIds = allUserOrgs.map(({ userOrgs: uo }) => uo.orgId);
// For each OLM, ensure there's a client in each org the user is in
for (const olm of userOlms) {
for (const userOrg of allUserOrgs) {
for (const userRoleOrg of allUserOrgs) {
const { userOrgs: userOrg, roles: role } = userRoleOrg;
const orgId = userOrg.orgId;
const [org] = await transaction
@@ -182,21 +187,46 @@ export async function calculateUserClientsForOrgs(
const niceId = await getUniqueClientName(orgId);
const isOrgLicensed = await isLicensedOrSubscribed(
userOrg.orgId
);
const requireApproval =
build !== "oss" &&
isOrgLicensed &&
role.requireDeviceApproval;
const newClientData: InferInsertModel<typeof clients> = {
userId,
orgId: userOrg.orgId,
exitNodeId: randomExitNode.exitNodeId,
name: olm.name || "User Client",
subnet: updatedSubnet,
olmId: olm.olmId,
type: "olm",
niceId,
approvalState: requireApproval ? "pending" : null
};
// Create the client
const [newClient] = await transaction
.insert(clients)
.values({
userId,
orgId: userOrg.orgId,
exitNodeId: randomExitNode.exitNodeId,
name: olm.name || "User Client",
subnet: updatedSubnet,
olmId: olm.olmId,
type: "olm",
niceId
})
.values(newClientData)
.returning();
// create approval request
if (requireApproval) {
await transaction
.insert(approvals)
.values({
timestamp: Math.floor(new Date().getTime() / 1000),
orgId: userOrg.orgId,
clientId: newClient.clientId,
userId,
type: "user_device"
})
.returning();
}
await rebuildClientAssociationsFromClient(
newClient,
transaction

View File

@@ -0,0 +1,15 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
export * from "./listApprovals";
export * from "./processPendingApproval";

View File

@@ -0,0 +1,188 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import type { Request, Response, NextFunction } from "express";
import { build } from "@server/build";
import { getOrgTierData } from "@server/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { approvals, clients, db, users, type Approval } from "@server/db";
import { eq, isNull, sql, not, and, desc } from "drizzle-orm";
import response from "@server/lib/response";
const paramsSchema = z.strictObject({
orgId: z.string()
});
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative()),
approvalState: z
.enum(["pending", "approved", "denied", "all"])
.optional()
.default("all")
.catch("all")
});
async function queryApprovals(
orgId: string,
limit: number,
offset: number,
approvalState: z.infer<typeof querySchema>["approvalState"]
) {
let state: Array<Approval["decision"]> = [];
switch (approvalState) {
case "pending":
state = ["pending"];
break;
case "approved":
state = ["approved"];
break;
case "denied":
state = ["denied"];
break;
default:
state = ["approved", "denied", "pending"];
}
const res = await db
.select({
approvalId: approvals.approvalId,
orgId: approvals.orgId,
clientId: approvals.clientId,
decision: approvals.decision,
type: approvals.type,
user: {
name: users.name,
userId: users.userId,
username: users.username
}
})
.from(approvals)
.innerJoin(users, and(eq(approvals.userId, users.userId)))
.leftJoin(
clients,
and(
eq(approvals.clientId, clients.clientId),
not(isNull(clients.userId)) // only user devices
)
)
.where(
and(
eq(approvals.orgId, orgId),
sql`${approvals.decision} in ${state}`
)
)
.orderBy(
sql`CASE ${approvals.decision} WHEN 'pending' THEN 0 ELSE 1 END`,
desc(approvals.timestamp)
)
.limit(limit)
.offset(offset);
return res;
}
export type ListApprovalsResponse = {
approvals: NonNullable<Awaited<ReturnType<typeof queryApprovals>>>;
pagination: { total: number; limit: number; offset: number };
};
export async function listApprovals(
req: Request,
res: Response,
next: NextFunction
) {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedQuery = querySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { limit, offset, approvalState } = parsedQuery.data;
const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const approvalsList = await queryApprovals(
orgId.toString(),
limit,
offset,
approvalState
);
const [{ count }] = await db
.select({ count: sql<number>`count(*)` })
.from(approvals);
return response<ListApprovalsResponse>(res, {
data: {
approvals: approvalsList,
pagination: {
total: count,
limit,
offset
}
},
success: true,
error: false,
message: "Approvals retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,142 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { build } from "@server/build";
import { approvals, clients, db, orgs, type Approval } from "@server/db";
import { getOrgTierData } from "@server/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import response from "@server/lib/response";
import { and, eq, type InferInsertModel } from "drizzle-orm";
import type { NextFunction, Request, Response } from "express";
const paramsSchema = z.strictObject({
orgId: z.string(),
approvalId: z.string().transform(Number).pipe(z.int().positive())
});
const bodySchema = z.strictObject({
decision: z.enum(["approved", "denied"])
});
export type ProcessApprovalResponse = Approval;
export async function processPendingApproval(
req: Request,
res: Response,
next: NextFunction
) {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { orgId, approvalId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const updateData = parsedBody.data;
const approval = await db
.select()
.from(approvals)
.where(
and(
eq(approvals.approvalId, approvalId),
eq(approvals.decision, "pending")
)
)
.innerJoin(orgs, eq(approvals.orgId, approvals.orgId))
.limit(1);
if (approval.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Pending Approval with ID ${approvalId} not found`
)
);
}
const [updatedApproval] = await db
.update(approvals)
.set(updateData)
.where(eq(approvals.approvalId, approvalId))
.returning();
// Update user device approval state too
if (
updatedApproval.type === "user_device" &&
updatedApproval.clientId
) {
const updateDataBody: Partial<InferInsertModel<typeof clients>> = {
approvalState: updateData.decision
};
if (updateData.decision === "denied") {
updateDataBody.blocked = true;
}
await db
.update(clients)
.set(updateDataBody)
.where(eq(clients.clientId, updatedApproval.clientId));
}
return response(res, {
data: updatedApproval,
success: true,
error: false,
message: "Approval updated successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -24,6 +24,7 @@ import * as generateLicense from "./generatedLicense";
import * as logs from "#private/routers/auditLogs";
import * as misc from "#private/routers/misc";
import * as reKey from "#private/routers/re-key";
import * as approval from "#private/routers/approvals";
import {
verifyOrgAccess,
@@ -311,6 +312,24 @@ authenticated.get(
loginPage.getLoginPage
);
authenticated.get(
"/org/:orgId/approvals",
verifyValidLicense,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listApprovals),
logActionAudit(ActionsEnum.listApprovals),
approval.listApprovals
);
authenticated.put(
"/org/:orgId/approvals/:approvalId",
verifyValidLicense,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.updateApprovals),
logActionAudit(ActionsEnum.updateApprovals),
approval.processPendingApproval
);
authenticated.get(
"/org/:orgId/login-page-branding",
verifyValidLicense,

View File

@@ -29,11 +29,9 @@ import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z
.object({
orgId: z.string()
})
.strict();
const paramsSchema = z.strictObject({
orgId: z.string()
});
export async function getLoginPageBranding(
req: Request,

View File

@@ -73,7 +73,7 @@ export async function blockClient(
// Block the client
await trx
.update(clients)
.set({ blocked: true })
.set({ blocked: true, approvalState: "denied" })
.where(eq(clients.clientId, clientId));
// Send terminate signal if there's an associated OLM and it's connected

View File

@@ -139,6 +139,7 @@ function queryClients(
userEmail: users.email,
niceId: clients.niceId,
agent: olms.agent,
approvalState: clients.approvalState,
olmArchived: olms.archived,
archived: clients.archived,
blocked: clients.blocked,

View File

@@ -71,7 +71,7 @@ export async function unblockClient(
// Unblock the client
await db
.update(clients)
.set({ blocked: false })
.set({ blocked: false, approvalState: null })
.where(eq(clients.clientId, clientId));
return response(res, {

View File

@@ -586,6 +586,14 @@ authenticated.get(
verifyUserHasAction(ActionsEnum.listRoles),
role.listRoles
);
authenticated.post(
"/org/:orgId/role/:roleId",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.updateRole),
logActionAudit(ActionsEnum.updateRole),
role.updateRole
);
// authenticated.get(
// "/role/:roleId",
// verifyRoleAccess,

View File

@@ -467,6 +467,14 @@ authenticated.put(
role.createRole
);
authenticated.post(
"/org/:orgId/role/:roleId",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.updateRole),
logActionAudit(ActionsEnum.updateRole),
role.updateRole
);
authenticated.get(
"/org/:orgId/roles",
verifyApiKeyOrgAccess,

View File

@@ -10,6 +10,8 @@ import { fromError } from "zod-validation-error";
import { ActionsEnum } from "@server/auth/actions";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build";
import { isLicensedOrSubscribed } from "@server/lib/isLicencedOrSubscribed";
const createRoleParamsSchema = z.strictObject({
orgId: z.string()
@@ -17,7 +19,8 @@ const createRoleParamsSchema = z.strictObject({
const createRoleSchema = z.strictObject({
name: z.string().min(1).max(255),
description: z.string().optional()
description: z.string().optional(),
requireDeviceApproval: z.boolean().optional()
});
export const defaultRoleAllowedActions: ActionsEnum[] = [
@@ -97,6 +100,11 @@ export async function createRole(
);
}
const isLicensed = await isLicensedOrSubscribed(orgId);
if (build === "oss" || !isLicensed) {
roleData.requireDeviceApproval = undefined;
}
await db.transaction(async (trx) => {
const newRole = await trx
.insert(roles)

View File

@@ -1,15 +1,13 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { roles, orgs } from "@server/db";
import { db, orgs, roles } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { sql, eq } from "drizzle-orm";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import stoi from "@server/lib/stoi";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import { eq, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const listRolesParamsSchema = z.strictObject({
orgId: z.string()
@@ -38,7 +36,8 @@ async function queryRoles(orgId: string, limit: number, offset: number) {
isAdmin: roles.isAdmin,
name: roles.name,
description: roles.description,
orgName: orgs.name
orgName: orgs.name,
requireDeviceApproval: roles.requireDeviceApproval
})
.from(roles)
.leftJoin(orgs, eq(roles.orgId, orgs.orgId))

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { db, orgs, type Role } from "@server/db";
import { roles } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
@@ -8,20 +8,28 @@ import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { build } from "@server/build";
import { isLicensedOrSubscribed } from "@server/lib/isLicencedOrSubscribed";
const updateRoleParamsSchema = z.strictObject({
orgId: z.string(),
roleId: z.string().transform(Number).pipe(z.int().positive())
});
const updateRoleBodySchema = z
.strictObject({
name: z.string().min(1).max(255).optional(),
description: z.string().optional()
description: z.string().optional(),
requireDeviceApproval: z.boolean().optional()
})
.refine((data) => Object.keys(data).length > 0, {
error: "At least one field must be provided for update"
});
export type UpdateRoleBody = z.infer<typeof updateRoleBodySchema>;
export type UpdateRoleResponse = Role;
export async function updateRole(
req: Request,
res: Response,
@@ -48,13 +56,14 @@ export async function updateRole(
);
}
const { roleId } = parsedParams.data;
const { roleId, orgId } = parsedParams.data;
const updateData = parsedBody.data;
const role = await db
.select()
.from(roles)
.where(eq(roles.roleId, roleId))
.innerJoin(orgs, eq(roles.orgId, orgs.orgId))
.limit(1);
if (role.length === 0) {
@@ -66,7 +75,7 @@ export async function updateRole(
);
}
if (role[0].isAdmin) {
if (role[0].roles.isAdmin) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
@@ -75,6 +84,11 @@ export async function updateRole(
);
}
const isLicensed = await isLicensedOrSubscribed(orgId);
if (build === "oss" || !isLicensed) {
updateData.requireDeviceApproval = undefined;
}
const updatedRole = await db
.update(roles)
.set(updateData)