add site targets, client resources, and auto login

This commit is contained in:
miloschwartz
2025-08-14 18:24:21 -07:00
parent 67ba225003
commit 5c04b1e14a
80 changed files with 5651 additions and 2385 deletions

View File

@@ -69,6 +69,11 @@ export enum ActionsEnum {
deleteResourceRule = "deleteResourceRule",
listResourceRules = "listResourceRules",
updateResourceRule = "updateResourceRule",
createSiteResource = "createSiteResource",
deleteSiteResource = "deleteSiteResource",
getSiteResource = "getSiteResource",
listSiteResources = "listSiteResources",
updateSiteResource = "updateSiteResource",
createClient = "createClient",
deleteClient = "deleteClient",
updateClient = "updateClient",

View File

@@ -66,11 +66,6 @@ export const sites = pgTable("sites", {
export const resources = pgTable("resources", {
resourceId: serial("resourceId").primaryKey(),
siteId: integer("siteId")
.references(() => sites.siteId, {
onDelete: "cascade"
})
.notNull(),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@@ -97,6 +92,9 @@ export const resources = pgTable("resources", {
tlsServerName: varchar("tlsServerName"),
setHostHeader: varchar("setHostHeader"),
enableProxy: boolean("enableProxy").default(true),
skipToIdpId: integer("skipToIdpId").references(() => idp.idpId, {
onDelete: "cascade"
}),
});
export const targets = pgTable("targets", {
@@ -106,6 +104,11 @@ export const targets = pgTable("targets", {
onDelete: "cascade"
})
.notNull(),
siteId: integer("siteId")
.references(() => sites.siteId, {
onDelete: "cascade"
})
.notNull(),
ip: varchar("ip").notNull(),
method: varchar("method"),
port: integer("port").notNull(),
@@ -124,6 +127,22 @@ export const exitNodes = pgTable("exitNodes", {
maxConnections: integer("maxConnections")
});
export const siteResources = pgTable("siteResources", { // this is for the clients
siteResourceId: serial("siteResourceId").primaryKey(),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
name: varchar("name").notNull(),
protocol: varchar("protocol").notNull(),
proxyPort: integer("proxyPort").notNull(),
destinationPort: integer("destinationPort").notNull(),
destinationIp: varchar("destinationIp").notNull(),
enabled: boolean("enabled").notNull().default(true),
});
export const users = pgTable("user", {
userId: varchar("id").primaryKey(),
email: varchar("email"),
@@ -647,4 +666,5 @@ export type OlmSession = InferSelectModel<typeof olmSessions>;
export type UserClient = InferSelectModel<typeof userClients>;
export type RoleClient = InferSelectModel<typeof roleClients>;
export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SiteResource = InferSelectModel<typeof siteResources>;
export type SetupToken = InferSelectModel<typeof setupTokens>;

View File

@@ -67,16 +67,11 @@ export const sites = sqliteTable("sites", {
dockerSocketEnabled: integer("dockerSocketEnabled", { mode: "boolean" })
.notNull()
.default(true),
remoteSubnets: text("remoteSubnets"), // comma-separated list of subnets that this site can access
remoteSubnets: text("remoteSubnets") // comma-separated list of subnets that this site can access
});
export const resources = sqliteTable("resources", {
resourceId: integer("resourceId").primaryKey({ autoIncrement: true }),
siteId: integer("siteId")
.references(() => sites.siteId, {
onDelete: "cascade"
})
.notNull(),
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@@ -109,6 +104,9 @@ export const resources = sqliteTable("resources", {
tlsServerName: text("tlsServerName"),
setHostHeader: text("setHostHeader"),
enableProxy: integer("enableProxy", { mode: "boolean" }).default(true),
skipToIdpId: integer("skipToIdpId").references(() => idp.idpId, {
onDelete: "cascade"
}),
});
export const targets = sqliteTable("targets", {
@@ -118,6 +116,11 @@ export const targets = sqliteTable("targets", {
onDelete: "cascade"
})
.notNull(),
siteId: integer("siteId")
.references(() => sites.siteId, {
onDelete: "cascade"
})
.notNull(),
ip: text("ip").notNull(),
method: text("method"),
port: integer("port").notNull(),
@@ -136,6 +139,22 @@ export const exitNodes = sqliteTable("exitNodes", {
maxConnections: integer("maxConnections")
});
export const siteResources = sqliteTable("siteResources", { // this is for the clients
siteResourceId: integer("siteResourceId").primaryKey({ autoIncrement: true }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
name: text("name").notNull(),
protocol: text("protocol").notNull(),
proxyPort: integer("proxyPort").notNull(),
destinationPort: integer("destinationPort").notNull(),
destinationIp: text("destinationIp").notNull(),
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true),
});
export const users = sqliteTable("user", {
userId: text("id").primaryKey(),
email: text("email"),
@@ -166,9 +185,11 @@ export const users = sqliteTable("user", {
export const securityKeys = sqliteTable("webauthnCredentials", {
credentialId: text("credentialId").primaryKey(),
userId: text("userId").notNull().references(() => users.userId, {
onDelete: "cascade"
}),
userId: text("userId")
.notNull()
.references(() => users.userId, {
onDelete: "cascade"
}),
publicKey: text("publicKey").notNull(),
signCount: integer("signCount").notNull(),
transports: text("transports"),
@@ -688,6 +709,7 @@ export type Idp = InferSelectModel<typeof idp>;
export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type SiteResource = InferSelectModel<typeof siteResources>;
export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SetupToken = InferSelectModel<typeof setupTokens>;
export type HostMeta = InferSelectModel<typeof hostMeta>;

View File

@@ -27,3 +27,4 @@ export * from "./verifyApiKeyAccess";
export * from "./verifyDomainAccess";
export * from "./verifyClientsEnabled";
export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess";

View File

@@ -0,0 +1,62 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { siteResources } from "@server/db";
import { eq, and } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
export async function verifySiteResourceAccess(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const siteResourceId = parseInt(req.params.siteResourceId);
const siteId = parseInt(req.params.siteId);
const orgId = req.params.orgId;
if (!siteResourceId || !siteId || !orgId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Missing required parameters"
)
);
}
// Check if the site resource exists and belongs to the specified site and org
const [siteResource] = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
))
.limit(1);
if (!siteResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site resource not found"
)
);
}
// Attach the siteResource to the request for use in the next middleware/route
// @ts-ignore - Extending Request type
req.siteResource = siteResource;
next();
} catch (error) {
logger.error("Error verifying site resource access:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying site resource access"
)
);
}
}

View File

@@ -0,0 +1,39 @@
import { sendToClient } from "../ws";
export async function addTargets(
newtId: string,
destinationIp: string,
destinationPort: number,
protocol: string,
port: number | null = null
) {
const target = `${port ? port + ":" : ""}${
destinationIp
}:${destinationPort}`;
await sendToClient(newtId, {
type: `newt/wg/${protocol}/add`,
data: {
targets: [target] // We can only use one target for WireGuard right now
}
});
}
export async function removeTargets(
newtId: string,
destinationIp: string,
destinationPort: number,
protocol: string,
port: number | null = null
) {
const target = `${port ? port + ":" : ""}${
destinationIp
}:${destinationPort}`;
await sendToClient(newtId, {
type: `newt/wg/${protocol}/remove`,
data: {
targets: [target] // We can only use one target for WireGuard right now
}
});
}

View File

@@ -9,6 +9,7 @@ import * as user from "./user";
import * as auth from "./auth";
import * as role from "./role";
import * as client from "./client";
import * as siteResource from "./siteResource";
import * as supporterKey from "./supporterKey";
import * as accessToken from "./accessToken";
import * as idp from "./idp";
@@ -34,7 +35,8 @@ import {
verifyDomainAccess,
verifyClientsEnabled,
verifyUserHasAction,
verifyUserIsOrgOwner
verifyUserIsOrgOwner,
verifySiteResourceAccess
} from "@server/middlewares";
import { createStore } from "@server/lib/rateLimitStore";
import { ActionsEnum } from "@server/auth/actions";
@@ -213,9 +215,60 @@ authenticated.get(
site.listContainers
);
// Site Resource endpoints
authenticated.put(
"/org/:orgId/site/:siteId/resource",
verifyOrgAccess,
verifySiteAccess,
verifyUserHasAction(ActionsEnum.createSiteResource),
siteResource.createSiteResource
);
authenticated.get(
"/org/:orgId/site/:siteId/resources",
verifyOrgAccess,
verifySiteAccess,
verifyUserHasAction(ActionsEnum.listSiteResources),
siteResource.listSiteResources
);
authenticated.get(
"/org/:orgId/site-resources",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listSiteResources),
siteResource.listAllSiteResourcesByOrg
);
authenticated.get(
"/org/:orgId/site/:siteId/resource/:siteResourceId",
verifyOrgAccess,
verifySiteAccess,
verifySiteResourceAccess,
verifyUserHasAction(ActionsEnum.getSiteResource),
siteResource.getSiteResource
);
authenticated.post(
"/org/:orgId/site/:siteId/resource/:siteResourceId",
verifyOrgAccess,
verifySiteAccess,
verifySiteResourceAccess,
verifyUserHasAction(ActionsEnum.updateSiteResource),
siteResource.updateSiteResource
);
authenticated.delete(
"/org/:orgId/site/:siteId/resource/:siteResourceId",
verifyOrgAccess,
verifySiteAccess,
verifySiteResourceAccess,
verifyUserHasAction(ActionsEnum.deleteSiteResource),
siteResource.deleteSiteResource
);
authenticated.put(
"/org/:orgId/resource",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createResource),
resource.createResource
);
@@ -397,28 +450,6 @@ authenticated.post(
user.addUserRole
);
// authenticated.put(
// "/role/:roleId/site",
// verifyRoleAccess,
// verifyUserInRole,
// verifyUserHasAction(ActionsEnum.addRoleSite),
// role.addRoleSite
// );
// authenticated.delete(
// "/role/:roleId/site",
// verifyRoleAccess,
// verifyUserInRole,
// verifyUserHasAction(ActionsEnum.removeRoleSite),
// role.removeRoleSite
// );
// authenticated.get(
// "/role/:roleId/sites",
// verifyRoleAccess,
// verifyUserInRole,
// verifyUserHasAction(ActionsEnum.listRoleSites),
// role.listRoleSites
// );
authenticated.post(
"/resource/:resourceId/roles",
verifyResourceAccess,
@@ -463,13 +494,6 @@ authenticated.get(
resource.getResourceWhitelist
);
authenticated.post(
`/resource/:resourceId/transfer`,
verifyResourceAccess,
verifyUserHasAction(ActionsEnum.updateResource),
resource.transferResource
);
authenticated.post(
`/resource/:resourceId/access-token`,
verifyResourceAccess,

View File

@@ -341,13 +341,6 @@ authenticated.get(
resource.getResourceWhitelist
);
authenticated.post(
`/resource/:resourceId/transfer`,
verifyApiKeyResourceAccess,
verifyApiKeyHasAction(ActionsEnum.updateResource),
resource.transferResource
);
authenticated.post(
`/resource/:resourceId/access-token`,
verifyApiKeyResourceAccess,

View File

@@ -220,78 +220,37 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Improved version
const allResources = await db.transaction(async (tx) => {
// First get all resources for the site
const resourcesList = await tx
.select({
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol
})
.from(resources)
.where(and(eq(resources.siteId, siteId), eq(resources.http, false)));
// Get all enabled targets with their resource protocol information
const allTargets = await db
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled,
protocol: resources.protocol
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
// Get all enabled targets for these resources in a single query
const resourceIds = resourcesList.map((r) => r.resourceId);
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled,
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
const { tcpTargets, udpTargets } = allTargets.reduce(
(acc, target) => {
// Filter out invalid targets
if (!target.internalPort || !target.ip || !target.port) {
return acc;
}
// Combine the data in JS instead of using SQL for the JSON
return resourcesList.map((resource) => ({
...resource,
targets: allTargets.filter(
(target) => target.resourceId === resource.resourceId
)
}));
});
const { tcpTargets, udpTargets } = allResources.reduce(
(acc, resource) => {
// Skip resources with no targets
if (!resource.targets?.length) return acc;
// Format valid targets into strings
const formattedTargets = resource.targets
.filter(
(target: Target) =>
resource.proxyPort && target?.ip && target?.port
)
.map(
(target: Target) =>
`${resource.proxyPort}:${target.ip}:${target.port}`
);
// Format target into string
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(...formattedTargets);
if (target.protocol === "tcp") {
acc.tcpTargets.push(formattedTarget);
} else {
acc.udpTargets.push(...formattedTargets);
acc.udpTargets.push(formattedTarget);
}
return acc;

View File

@@ -105,7 +105,9 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
.limit(1);
const blockSize = config.getRawConfig().gerbil.site_block_size;
const subnets = sitesQuery.map((site) => site.subnet).filter((subnet) => subnet !== null);
const subnets = sitesQuery
.map((site) => site.subnet)
.filter((subnet) => subnet !== null);
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
const newSubnet = findNextAvailableCidr(
subnets,
@@ -160,78 +162,37 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
allowedIps: [siteSubnet]
});
// Improved version
const allResources = await db.transaction(async (tx) => {
// First get all resources for the site
const resourcesList = await tx
.select({
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol
})
.from(resources)
.where(eq(resources.siteId, siteId));
// Get all enabled targets with their resource protocol information
const allTargets = await db
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled,
protocol: resources.protocol
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
// Get all enabled targets for these resources in a single query
const resourceIds = resourcesList.map((r) => r.resourceId);
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
const { tcpTargets, udpTargets } = allTargets.reduce(
(acc, target) => {
// Filter out invalid targets
if (!target.internalPort || !target.ip || !target.port) {
return acc;
}
// Combine the data in JS instead of using SQL for the JSON
return resourcesList.map((resource) => ({
...resource,
targets: allTargets.filter(
(target) => target.resourceId === resource.resourceId
)
}));
});
const { tcpTargets, udpTargets } = allResources.reduce(
(acc, resource) => {
// Skip resources with no targets
if (!resource.targets?.length) return acc;
// Format valid targets into strings
const formattedTargets = resource.targets
.filter(
(target: Target) =>
target?.internalPort && target?.ip && target?.port
)
.map(
(target: Target) =>
`${target.internalPort}:${target.ip}:${target.port}`
);
// Format target into string
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(...formattedTargets);
if (target.protocol === "tcp") {
acc.tcpTargets.push(formattedTarget);
} else {
acc.udpTargets.push(...formattedTargets);
acc.udpTargets.push(formattedTarget);
}
return acc;

View File

@@ -1,7 +1,8 @@
import { Target } from "@server/db";
import { sendToClient } from "../ws";
import logger from "@server/logger";
export function addTargets(
export async function addTargets(
newtId: string,
targets: Target[],
protocol: string,
@@ -20,22 +21,9 @@ export function addTargets(
targets: payloadTargets
}
});
const payloadTargetsResources = targets.map((target) => {
return `${port ? port + ":" : ""}${
target.ip
}:${target.port}`;
});
sendToClient(newtId, {
type: `newt/wg/${protocol}/add`,
data: {
targets: [payloadTargetsResources[0]] // We can only use one target for WireGuard right now
}
});
}
export function removeTargets(
export async function removeTargets(
newtId: string,
targets: Target[],
protocol: string,
@@ -48,23 +36,10 @@ export function removeTargets(
}:${target.port}`;
});
sendToClient(newtId, {
await sendToClient(newtId, {
type: `newt/${protocol}/remove`,
data: {
targets: payloadTargets
}
});
const payloadTargetsResources = targets.map((target) => {
return `${port ? port + ":" : ""}${
target.ip
}:${target.port}`;
});
sendToClient(newtId, {
type: `newt/wg/${protocol}/remove`,
data: {
targets: [payloadTargetsResources[0]] // We can only use one target for WireGuard right now
}
});
}

View File

@@ -15,7 +15,6 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import stoi from "@server/lib/stoi";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { subdomainSchema } from "@server/lib/schemas";
@@ -25,7 +24,6 @@ import { build } from "@server/build";
const createResourceParamsSchema = z
.object({
siteId: z.string().transform(stoi).pipe(z.number().int().positive()),
orgId: z.string()
})
.strict();
@@ -34,7 +32,6 @@ const createHttpResourceSchema = z
.object({
name: z.string().min(1).max(255),
subdomain: z.string().nullable().optional(),
siteId: z.number(),
http: z.boolean(),
protocol: z.enum(["tcp", "udp"]),
domainId: z.string()
@@ -53,11 +50,10 @@ const createHttpResourceSchema = z
const createRawResourceSchema = z
.object({
name: z.string().min(1).max(255),
siteId: z.number(),
http: z.boolean(),
protocol: z.enum(["tcp", "udp"]),
proxyPort: z.number().int().min(1).max(65535),
enableProxy: z.boolean().default(true)
// enableProxy: z.boolean().default(true) // always true now
})
.strict()
.refine(
@@ -78,7 +74,7 @@ export type CreateResourceResponse = Resource;
registry.registerPath({
method: "put",
path: "/org/{orgId}/site/{siteId}/resource",
path: "/org/{orgId}/resource",
description: "Create a resource.",
tags: [OpenAPITags.Org, OpenAPITags.Resource],
request: {
@@ -111,7 +107,7 @@ export async function createResource(
);
}
const { siteId, orgId } = parsedParams.data;
const { orgId } = parsedParams.data;
if (req.user && !req.userOrgRoleId) {
return next(
@@ -146,7 +142,7 @@ export async function createResource(
if (http) {
return await createHttpResource(
{ req, res, next },
{ siteId, orgId }
{ orgId }
);
} else {
if (
@@ -162,7 +158,7 @@ export async function createResource(
}
return await createRawResource(
{ req, res, next },
{ siteId, orgId }
{ orgId }
);
}
} catch (error) {
@@ -180,12 +176,11 @@ async function createHttpResource(
next: NextFunction;
},
meta: {
siteId: number;
orgId: string;
}
) {
const { req, res, next } = route;
const { siteId, orgId } = meta;
const { orgId } = meta;
const parsedBody = createHttpResourceSchema.safeParse(req.body);
if (!parsedBody.success) {
@@ -292,7 +287,6 @@ async function createHttpResource(
const newResource = await trx
.insert(resources)
.values({
siteId,
fullDomain,
domainId,
orgId,
@@ -357,12 +351,11 @@ async function createRawResource(
next: NextFunction;
},
meta: {
siteId: number;
orgId: string;
}
) {
const { req, res, next } = route;
const { siteId, orgId } = meta;
const { orgId } = meta;
const parsedBody = createRawResourceSchema.safeParse(req.body);
if (!parsedBody.success) {
@@ -374,7 +367,7 @@ async function createRawResource(
);
}
const { name, http, protocol, proxyPort, enableProxy } = parsedBody.data;
const { name, http, protocol, proxyPort } = parsedBody.data;
// if http is false check to see if there is already a resource with the same port and protocol
const existingResource = await db
@@ -402,13 +395,12 @@ async function createRawResource(
const newResource = await trx
.insert(resources)
.values({
siteId,
orgId,
name,
http,
protocol,
proxyPort,
enableProxy
// enableProxy
})
.returning();

View File

@@ -71,44 +71,44 @@ export async function deleteResource(
);
}
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, deletedResource.siteId!))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${deletedResource.siteId} not found`
)
);
}
if (site.pubKey) {
if (site.type == "wireguard") {
await addPeer(site.exitNodeId!, {
publicKey: site.pubKey,
allowedIps: await getAllowedIps(site.siteId)
});
} else if (site.type == "newt") {
// get the newt on the site by querying the newt table for siteId
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
removeTargets(
newt.newtId,
targetsToBeRemoved,
deletedResource.protocol,
deletedResource.proxyPort
);
}
}
// const [site] = await db
// .select()
// .from(sites)
// .where(eq(sites.siteId, deletedResource.siteId!))
// .limit(1);
//
// if (!site) {
// return next(
// createHttpError(
// HttpCode.NOT_FOUND,
// `Site with ID ${deletedResource.siteId} not found`
// )
// );
// }
//
// if (site.pubKey) {
// if (site.type == "wireguard") {
// await addPeer(site.exitNodeId!, {
// publicKey: site.pubKey,
// allowedIps: await getAllowedIps(site.siteId)
// });
// } else if (site.type == "newt") {
// // get the newt on the site by querying the newt table for siteId
// const [newt] = await db
// .select()
// .from(newts)
// .where(eq(newts.siteId, site.siteId))
// .limit(1);
//
// removeTargets(
// newt.newtId,
// targetsToBeRemoved,
// deletedResource.protocol,
// deletedResource.proxyPort
// );
// }
// }
//
return response(res, {
data: null,
success: true,

View File

@@ -19,9 +19,7 @@ const getResourceSchema = z
})
.strict();
export type GetResourceResponse = Resource & {
siteName: string;
};
export type GetResourceResponse = Resource;
registry.registerPath({
method: "get",
@@ -56,11 +54,9 @@ export async function getResource(
.select()
.from(resources)
.where(eq(resources.resourceId, resourceId))
.leftJoin(sites, eq(sites.siteId, resources.siteId))
.limit(1);
const resource = resp.resources;
const site = resp.sites;
const resource = resp;
if (!resource) {
return next(
@@ -73,8 +69,7 @@ export async function getResource(
return response(res, {
data: {
...resource,
siteName: site?.name
...resource
},
success: true,
error: false,

View File

@@ -31,6 +31,7 @@ export type GetResourceAuthInfoResponse = {
blockAccess: boolean;
url: string;
whitelist: boolean;
skipToIdpId: number | null;
};
export async function getResourceAuthInfo(
@@ -86,7 +87,8 @@ export async function getResourceAuthInfo(
sso: resource.sso,
blockAccess: resource.blockAccess,
url,
whitelist: resource.emailWhitelistEnabled
whitelist: resource.emailWhitelistEnabled,
skipToIdpId: resource.skipToIdpId
},
success: true,
error: false,

View File

@@ -1,16 +1,14 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { and, eq, or, inArray } from "drizzle-orm";
import {
resources,
userResources,
roleResources,
userOrgs,
roles,
import {
resources,
userResources,
roleResources,
userOrgs,
resourcePassword,
resourcePincode,
resourceWhitelist,
sites
resourceWhitelist
} from "@server/db";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
@@ -37,12 +35,7 @@ export async function getUserResources(
roleId: userOrgs.roleId
})
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, orgId)
)
)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
.limit(1);
if (userOrgResult.length === 0) {
@@ -71,8 +64,8 @@ export async function getUserResources(
// Combine all accessible resource IDs
const accessibleResourceIds = [
...directResources.map(r => r.resourceId),
...roleResourceResults.map(r => r.resourceId)
...directResources.map((r) => r.resourceId),
...roleResourceResults.map((r) => r.resourceId)
];
if (accessibleResourceIds.length === 0) {
@@ -95,11 +88,9 @@ export async function getUserResources(
enabled: resources.enabled,
sso: resources.sso,
protocol: resources.protocol,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
siteName: sites.name
emailWhitelistEnabled: resources.emailWhitelistEnabled
})
.from(resources)
.leftJoin(sites, eq(sites.siteId, resources.siteId))
.where(
and(
inArray(resources.resourceId, accessibleResourceIds),
@@ -111,28 +102,61 @@ export async function getUserResources(
// Check for password, pincode, and whitelist protection for each resource
const resourcesWithAuth = await Promise.all(
resourcesData.map(async (resource) => {
const [passwordCheck, pincodeCheck, whitelistCheck] = await Promise.all([
db.select().from(resourcePassword).where(eq(resourcePassword.resourceId, resource.resourceId)).limit(1),
db.select().from(resourcePincode).where(eq(resourcePincode.resourceId, resource.resourceId)).limit(1),
db.select().from(resourceWhitelist).where(eq(resourceWhitelist.resourceId, resource.resourceId)).limit(1)
]);
const [passwordCheck, pincodeCheck, whitelistCheck] =
await Promise.all([
db
.select()
.from(resourcePassword)
.where(
eq(
resourcePassword.resourceId,
resource.resourceId
)
)
.limit(1),
db
.select()
.from(resourcePincode)
.where(
eq(
resourcePincode.resourceId,
resource.resourceId
)
)
.limit(1),
db
.select()
.from(resourceWhitelist)
.where(
eq(
resourceWhitelist.resourceId,
resource.resourceId
)
)
.limit(1)
]);
const hasPassword = passwordCheck.length > 0;
const hasPincode = pincodeCheck.length > 0;
const hasWhitelist = whitelistCheck.length > 0 || resource.emailWhitelistEnabled;
const hasWhitelist =
whitelistCheck.length > 0 || resource.emailWhitelistEnabled;
return {
resourceId: resource.resourceId,
name: resource.name,
domain: `${resource.ssl ? "https://" : "http://"}${resource.fullDomain}`,
enabled: resource.enabled,
protected: !!(resource.sso || hasPassword || hasPincode || hasWhitelist),
protected: !!(
resource.sso ||
hasPassword ||
hasPincode ||
hasWhitelist
),
protocol: resource.protocol,
sso: resource.sso,
password: hasPassword,
pincode: hasPincode,
whitelist: hasWhitelist,
siteName: resource.siteName
whitelist: hasWhitelist
};
})
);
@@ -144,11 +168,13 @@ export async function getUserResources(
message: "User resources retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
console.error("Error fetching user resources:", error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Internal server error")
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Internal server error"
)
);
}
}
@@ -165,4 +191,4 @@ export type GetUserResourcesResponse = {
protocol: string;
}>;
};
};
};

View File

@@ -16,10 +16,9 @@ export * from "./setResourceWhitelist";
export * from "./getResourceWhitelist";
export * from "./authWithWhitelist";
export * from "./authWithAccessToken";
export * from "./transferResource";
export * from "./getExchangeToken";
export * from "./createResourceRule";
export * from "./deleteResourceRule";
export * from "./listResourceRules";
export * from "./updateResourceRule";
export * from "./getUserResources";
export * from "./getUserResources";

View File

@@ -3,7 +3,6 @@ import { z } from "zod";
import { db } from "@server/db";
import {
resources,
sites,
userResources,
roleResources,
resourcePassword,
@@ -20,17 +19,9 @@ import { OpenAPITags, registry } from "@server/openApi";
const listResourcesParamsSchema = z
.object({
siteId: z
.string()
.optional()
.transform(stoi)
.pipe(z.number().int().positive().optional()),
orgId: z.string().optional()
orgId: z.string()
})
.strict()
.refine((data) => !!data.siteId !== !!data.orgId, {
message: "Either siteId or orgId must be provided, but not both"
});
.strict();
const listResourcesSchema = z.object({
limit: z
@@ -48,82 +39,38 @@ const listResourcesSchema = z.object({
.pipe(z.number().int().nonnegative())
});
function queryResources(
accessibleResourceIds: number[],
siteId?: number,
orgId?: string
) {
if (siteId) {
return db
.select({
resourceId: resources.resourceId,
name: resources.name,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
siteName: sites.name,
siteId: sites.niceId,
passwordId: resourcePassword.passwordId,
pincodeId: resourcePincode.pincodeId,
sso: resources.sso,
whitelist: resources.emailWhitelistEnabled,
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(sites, eq(resources.siteId, sites.siteId))
.leftJoin(
resourcePassword,
eq(resourcePassword.resourceId, resources.resourceId)
function queryResources(accessibleResourceIds: number[], orgId: string) {
return db
.select({
resourceId: resources.resourceId,
name: resources.name,
ssl: resources.ssl,
fullDomain: resources.fullDomain,
passwordId: resourcePassword.passwordId,
sso: resources.sso,
pincodeId: resourcePincode.pincodeId,
whitelist: resources.emailWhitelistEnabled,
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(
resourcePassword,
eq(resourcePassword.resourceId, resources.resourceId)
)
.leftJoin(
resourcePincode,
eq(resourcePincode.resourceId, resources.resourceId)
)
.where(
and(
inArray(resources.resourceId, accessibleResourceIds),
eq(resources.orgId, orgId)
)
.leftJoin(
resourcePincode,
eq(resourcePincode.resourceId, resources.resourceId)
)
.where(
and(
inArray(resources.resourceId, accessibleResourceIds),
eq(resources.siteId, siteId)
)
);
} else if (orgId) {
return db
.select({
resourceId: resources.resourceId,
name: resources.name,
ssl: resources.ssl,
fullDomain: resources.fullDomain,
siteName: sites.name,
siteId: sites.niceId,
passwordId: resourcePassword.passwordId,
sso: resources.sso,
pincodeId: resourcePincode.pincodeId,
whitelist: resources.emailWhitelistEnabled,
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(sites, eq(resources.siteId, sites.siteId))
.leftJoin(
resourcePassword,
eq(resourcePassword.resourceId, resources.resourceId)
)
.leftJoin(
resourcePincode,
eq(resourcePincode.resourceId, resources.resourceId)
)
.where(
and(
inArray(resources.resourceId, accessibleResourceIds),
eq(resources.orgId, orgId)
)
);
}
);
}
export type ListResourcesResponse = {
@@ -131,20 +78,6 @@ export type ListResourcesResponse = {
pagination: { total: number; limit: number; offset: number };
};
registry.registerPath({
method: "get",
path: "/site/{siteId}/resources",
description: "List resources for a site.",
tags: [OpenAPITags.Site, OpenAPITags.Resource],
request: {
params: z.object({
siteId: z.number()
}),
query: listResourcesSchema
},
responses: {}
});
registry.registerPath({
method: "get",
path: "/org/{orgId}/resources",
@@ -185,9 +118,11 @@ export async function listResources(
)
);
}
const { siteId } = parsedParams.data;
const orgId = parsedParams.data.orgId || req.userOrg?.orgId || req.apiKeyOrg?.orgId;
const orgId =
parsedParams.data.orgId ||
req.userOrg?.orgId ||
req.apiKeyOrg?.orgId;
if (!orgId) {
return next(
@@ -207,24 +142,27 @@ export async function listResources(
let accessibleResources;
if (req.user) {
accessibleResources = await db
.select({
resourceId: sql<number>`COALESCE(${userResources.resourceId}, ${roleResources.resourceId})`
})
.from(userResources)
.fullJoin(
roleResources,
eq(userResources.resourceId, roleResources.resourceId)
)
.where(
or(
eq(userResources.userId, req.user!.userId),
eq(roleResources.roleId, req.userOrgRoleId!)
.select({
resourceId: sql<number>`COALESCE(${userResources.resourceId}, ${roleResources.resourceId})`
})
.from(userResources)
.fullJoin(
roleResources,
eq(userResources.resourceId, roleResources.resourceId)
)
);
.where(
or(
eq(userResources.userId, req.user!.userId),
eq(roleResources.roleId, req.userOrgRoleId!)
)
);
} else {
accessibleResources = await db.select({
resourceId: resources.resourceId
}).from(resources).where(eq(resources.orgId, orgId));
accessibleResources = await db
.select({
resourceId: resources.resourceId
})
.from(resources)
.where(eq(resources.orgId, orgId));
}
const accessibleResourceIds = accessibleResources.map(
@@ -236,7 +174,7 @@ export async function listResources(
.from(resources)
.where(inArray(resources.resourceId, accessibleResourceIds));
const baseQuery = queryResources(accessibleResourceIds, siteId, orgId);
const baseQuery = queryResources(accessibleResourceIds, orgId);
const resourcesList = await baseQuery!.limit(limit).offset(offset);
const totalCountResult = await countQuery;

View File

@@ -1,214 +0,0 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { newts, resources, sites, targets } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { addPeer } from "../gerbil/peers";
import { addTargets, removeTargets } from "../newt/targets";
import { getAllowedIps } from "../target/helpers";
import { OpenAPITags, registry } from "@server/openApi";
const transferResourceParamsSchema = z
.object({
resourceId: z
.string()
.transform(Number)
.pipe(z.number().int().positive())
})
.strict();
const transferResourceBodySchema = z
.object({
siteId: z.number().int().positive()
})
.strict();
registry.registerPath({
method: "post",
path: "/resource/{resourceId}/transfer",
description:
"Transfer a resource to a different site. This will also transfer the targets associated with the resource.",
tags: [OpenAPITags.Resource],
request: {
params: transferResourceParamsSchema,
body: {
content: {
"application/json": {
schema: transferResourceBodySchema
}
}
}
},
responses: {}
});
export async function transferResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = transferResourceParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = transferResourceBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { resourceId } = parsedParams.data;
const { siteId } = parsedBody.data;
const [oldResource] = await db
.select()
.from(resources)
.where(eq(resources.resourceId, resourceId))
.limit(1);
if (!oldResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Resource with ID ${resourceId} not found`
)
);
}
if (oldResource.siteId === siteId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Resource is already assigned to this site`
)
);
}
const [newSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!newSite) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found`
)
);
}
const [oldSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, oldResource.siteId))
.limit(1);
if (!oldSite) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${oldResource.siteId} not found`
)
);
}
const [updatedResource] = await db
.update(resources)
.set({ siteId })
.where(eq(resources.resourceId, resourceId))
.returning();
if (!updatedResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Resource with ID ${resourceId} not found`
)
);
}
const resourceTargets = await db
.select()
.from(targets)
.where(eq(targets.resourceId, resourceId));
if (resourceTargets.length > 0) {
////// REMOVE THE TARGETS FROM THE OLD SITE //////
if (oldSite.pubKey) {
if (oldSite.type == "wireguard") {
await addPeer(oldSite.exitNodeId!, {
publicKey: oldSite.pubKey,
allowedIps: await getAllowedIps(oldSite.siteId)
});
} else if (oldSite.type == "newt") {
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, oldSite.siteId))
.limit(1);
removeTargets(
newt.newtId,
resourceTargets,
updatedResource.protocol,
updatedResource.proxyPort
);
}
}
////// ADD THE TARGETS TO THE NEW SITE //////
if (newSite.pubKey) {
if (newSite.type == "wireguard") {
await addPeer(newSite.exitNodeId!, {
publicKey: newSite.pubKey,
allowedIps: await getAllowedIps(newSite.siteId)
});
} else if (newSite.type == "newt") {
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, newSite.siteId))
.limit(1);
addTargets(
newt.newtId,
resourceTargets,
updatedResource.protocol,
updatedResource.proxyPort
);
}
}
}
return response(res, {
data: updatedResource,
success: true,
error: false,
message: "Resource transferred successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -20,7 +20,6 @@ import { tlsNameSchema } from "@server/lib/schemas";
import { subdomainSchema } from "@server/lib/schemas";
import { registry } from "@server/openApi";
import { OpenAPITags } from "@server/openApi";
import { build } from "@server/build";
const updateResourceParamsSchema = z
.object({
@@ -44,7 +43,8 @@ const updateHttpResourceBodySchema = z
enabled: z.boolean().optional(),
stickySession: z.boolean().optional(),
tlsServerName: z.string().nullable().optional(),
setHostHeader: z.string().nullable().optional()
setHostHeader: z.string().nullable().optional(),
skipToIdpId: z.number().int().positive().nullable().optional()
})
.strict()
.refine((data) => Object.keys(data).length > 0, {
@@ -91,8 +91,8 @@ const updateRawResourceBodySchema = z
name: z.string().min(1).max(255).optional(),
proxyPort: z.number().int().min(1).max(65535).optional(),
stickySession: z.boolean().optional(),
enabled: z.boolean().optional(),
enableProxy: z.boolean().optional()
enabled: z.boolean().optional()
// enableProxy: z.boolean().optional() // always true now
})
.strict()
.refine((data) => Object.keys(data).length > 0, {

View File

@@ -60,18 +60,18 @@ export async function addRoleSite(
})
.returning();
const siteResources = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await trx.insert(roleResources).values({
roleId,
resourceId: resource.resourceId
});
}
// const siteResources = await db
// .select()
// .from(resources)
// .where(eq(resources.siteId, siteId));
//
// for (const resource of siteResources) {
// await trx.insert(roleResources).values({
// roleId,
// resourceId: resource.resourceId
// });
// }
//
return response(res, {
data: newRoleSite[0],
success: true,

View File

@@ -1,6 +1,5 @@
export * from "./addRoleAction";
export * from "../resource/setResourceRoles";
export * from "./addRoleSite";
export * from "./createRole";
export * from "./deleteRole";
export * from "./getRole";
@@ -11,5 +10,4 @@ export * from "./listRoles";
export * from "./listRoleSites";
export * from "./removeRoleAction";
export * from "./removeRoleResource";
export * from "./removeRoleSite";
export * from "./updateRole";
export * from "./updateRole";

View File

@@ -71,22 +71,22 @@ export async function removeRoleSite(
);
}
const siteResources = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await trx
.delete(roleResources)
.where(
and(
eq(roleResources.roleId, roleId),
eq(roleResources.resourceId, resource.resourceId)
)
)
.returning();
}
// const siteResources = await db
// .select()
// .from(resources)
// .where(eq(resources.siteId, siteId));
//
// for (const resource of siteResources) {
// await trx
// .delete(roleResources)
// .where(
// and(
// eq(roleResources.roleId, roleId),
// eq(roleResources.resourceId, resource.resourceId)
// )
// )
// .returning();
// }
});
return response(res, {

View File

@@ -0,0 +1,171 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts } from "@server/db";
import { siteResources, sites, orgs, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { addTargets } from "../client/targets";
const createSiteResourceParamsSchema = z
.object({
siteId: z.string().transform(Number).pipe(z.number().int().positive()),
orgId: z.string()
})
.strict();
const createSiteResourceSchema = z
.object({
name: z.string().min(1).max(255),
protocol: z.enum(["tcp", "udp"]),
proxyPort: z.number().int().positive(),
destinationPort: z.number().int().positive(),
destinationIp: z.string().ip(),
enabled: z.boolean().default(true)
})
.strict();
export type CreateSiteResourceBody = z.infer<typeof createSiteResourceSchema>;
export type CreateSiteResourceResponse = SiteResource;
registry.registerPath({
method: "put",
path: "/org/{orgId}/site/{siteId}/resource",
description: "Create a new site resource.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: createSiteResourceParamsSchema,
body: {
content: {
"application/json": {
schema: createSiteResourceSchema
}
}
}
},
responses: {}
});
export async function createSiteResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = createSiteResourceParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = createSiteResourceSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { siteId, orgId } = parsedParams.data;
const {
name,
protocol,
proxyPort,
destinationPort,
destinationIp,
enabled
} = parsedBody.data;
// Verify the site exists and belongs to the org
const [site] = await db
.select()
.from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// check if resource with same protocol and proxy port already exists
const [existingResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId),
eq(siteResources.protocol, protocol),
eq(siteResources.proxyPort, proxyPort)
)
)
.limit(1);
if (existingResource && existingResource.siteResourceId) {
return next(
createHttpError(
HttpCode.CONFLICT,
"A resource with the same protocol and proxy port already exists"
)
);
}
// Create the site resource
const [newSiteResource] = await db
.insert(siteResources)
.values({
siteId,
orgId,
name,
protocol,
proxyPort,
destinationPort,
destinationIp,
enabled
})
.returning();
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found"));
}
await addTargets(newt.newtId, destinationIp, destinationPort, protocol);
logger.info(
`Created site resource ${newSiteResource.siteResourceId} for site ${siteId}`
);
return response(res, {
data: newSiteResource,
success: true,
error: false,
message: "Site resource created successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error("Error creating site resource:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create site resource"
)
);
}
}

View File

@@ -0,0 +1,124 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts, sites } from "@server/db";
import { siteResources } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { removeTargets } from "../client/targets";
const deleteSiteResourceParamsSchema = z
.object({
siteResourceId: z.string().transform(Number).pipe(z.number().int().positive()),
siteId: z.string().transform(Number).pipe(z.number().int().positive()),
orgId: z.string()
})
.strict();
export type DeleteSiteResourceResponse = {
message: string;
};
registry.registerPath({
method: "delete",
path: "/org/{orgId}/site/{siteId}/resource/{siteResourceId}",
description: "Delete a site resource.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: deleteSiteResourceParamsSchema
},
responses: {}
});
export async function deleteSiteResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = deleteSiteResourceParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { siteResourceId, siteId, orgId } = parsedParams.data;
const [site] = await db
.select()
.from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists
const [existingSiteResource] = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
))
.limit(1);
if (!existingSiteResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site resource not found"
)
);
}
// Delete the site resource
await db
.delete(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
));
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found"));
}
await removeTargets(
newt.newtId,
existingSiteResource.destinationIp,
existingSiteResource.destinationPort,
existingSiteResource.protocol
);
logger.info(`Deleted site resource ${siteResourceId} for site ${siteId}`);
return response(res, {
data: { message: "Site resource deleted successfully" },
success: true,
error: false,
message: "Site resource deleted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error deleting site resource:", error);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to delete site resource"));
}
}

View File

@@ -0,0 +1,83 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
const getSiteResourceParamsSchema = z
.object({
siteResourceId: z.string().transform(Number).pipe(z.number().int().positive()),
siteId: z.string().transform(Number).pipe(z.number().int().positive()),
orgId: z.string()
})
.strict();
export type GetSiteResourceResponse = SiteResource;
registry.registerPath({
method: "get",
path: "/org/{orgId}/site/{siteId}/resource/{siteResourceId}",
description: "Get a specific site resource.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: getSiteResourceParamsSchema
},
responses: {}
});
export async function getSiteResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = getSiteResourceParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { siteResourceId, siteId, orgId } = parsedParams.data;
// Get the site resource
const [siteResource] = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
))
.limit(1);
if (!siteResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site resource not found"
)
);
}
return response(res, {
data: siteResource,
success: true,
error: false,
message: "Site resource retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error getting site resource:", error);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to get site resource"));
}
}

View File

@@ -0,0 +1,6 @@
export * from "./createSiteResource";
export * from "./deleteSiteResource";
export * from "./getSiteResource";
export * from "./updateSiteResource";
export * from "./listSiteResources";
export * from "./listAllSiteResourcesByOrg";

View File

@@ -0,0 +1,111 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
const listAllSiteResourcesByOrgParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const listAllSiteResourcesByOrgQuerySchema = z.object({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.number().int().positive()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.number().int().nonnegative())
});
export type ListAllSiteResourcesByOrgResponse = {
siteResources: (SiteResource & { siteName: string, siteNiceId: string })[];
};
registry.registerPath({
method: "get",
path: "/org/{orgId}/site-resources",
description: "List all site resources for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: listAllSiteResourcesByOrgParamsSchema,
query: listAllSiteResourcesByOrgQuerySchema
},
responses: {}
});
export async function listAllSiteResourcesByOrg(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = listAllSiteResourcesByOrgParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedQuery = listAllSiteResourcesByOrgQuerySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const { limit, offset } = parsedQuery.data;
// Get all site resources for the org with site names
const siteResourcesList = await db
.select({
siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId,
name: siteResources.name,
protocol: siteResources.protocol,
proxyPort: siteResources.proxyPort,
destinationPort: siteResources.destinationPort,
destinationIp: siteResources.destinationIp,
enabled: siteResources.enabled,
siteName: sites.name,
siteNiceId: sites.niceId
})
.from(siteResources)
.innerJoin(sites, eq(siteResources.siteId, sites.siteId))
.where(eq(siteResources.orgId, orgId))
.limit(limit)
.offset(offset);
return response(res, {
data: { siteResources: siteResourcesList },
success: true,
error: false,
message: "Site resources retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error listing all site resources by org:", error);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to list site resources"));
}
}

View File

@@ -0,0 +1,118 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
const listSiteResourcesParamsSchema = z
.object({
siteId: z.string().transform(Number).pipe(z.number().int().positive()),
orgId: z.string()
})
.strict();
const listSiteResourcesQuerySchema = z.object({
limit: z
.string()
.optional()
.default("100")
.transform(Number)
.pipe(z.number().int().positive()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.number().int().nonnegative())
});
export type ListSiteResourcesResponse = {
siteResources: SiteResource[];
};
registry.registerPath({
method: "get",
path: "/org/{orgId}/site/{siteId}/resources",
description: "List site resources for a site.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: listSiteResourcesParamsSchema,
query: listSiteResourcesQuerySchema
},
responses: {}
});
export async function listSiteResources(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = listSiteResourcesParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedQuery = listSiteResourcesQuerySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { siteId, orgId } = parsedParams.data;
const { limit, offset } = parsedQuery.data;
// Verify the site exists and belongs to the org
const site = await db
.select()
.from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1);
if (site.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site not found"
)
);
}
// Get site resources
const siteResourcesList = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
))
.limit(limit)
.offset(offset);
return response(res, {
data: { siteResources: siteResourcesList },
success: true,
error: false,
message: "Site resources retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error listing site resources:", error);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to list site resources"));
}
}

View File

@@ -0,0 +1,196 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts, sites } from "@server/db";
import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { addTargets } from "../client/targets";
const updateSiteResourceParamsSchema = z
.object({
siteResourceId: z
.string()
.transform(Number)
.pipe(z.number().int().positive()),
siteId: z.string().transform(Number).pipe(z.number().int().positive()),
orgId: z.string()
})
.strict();
const updateSiteResourceSchema = z
.object({
name: z.string().min(1).max(255).optional(),
protocol: z.enum(["tcp", "udp"]).optional(),
proxyPort: z.number().int().positive().optional(),
destinationPort: z.number().int().positive().optional(),
destinationIp: z.string().ip().optional(),
enabled: z.boolean().optional()
})
.strict();
export type UpdateSiteResourceBody = z.infer<typeof updateSiteResourceSchema>;
export type UpdateSiteResourceResponse = SiteResource;
registry.registerPath({
method: "post",
path: "/org/{orgId}/site/{siteId}/resource/{siteResourceId}",
description: "Update a site resource.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: updateSiteResourceParamsSchema,
body: {
content: {
"application/json": {
schema: updateSiteResourceSchema
}
}
}
},
responses: {}
});
export async function updateSiteResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = updateSiteResourceParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = updateSiteResourceSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { siteResourceId, siteId, orgId } = parsedParams.data;
const updateData = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists
const [existingSiteResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
.limit(1);
if (!existingSiteResource) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Site resource not found")
);
}
const protocol = updateData.protocol || existingSiteResource.protocol;
const proxyPort =
updateData.proxyPort || existingSiteResource.proxyPort;
// check if resource with same protocol and proxy port already exists
const [existingResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId),
eq(siteResources.protocol, protocol),
eq(siteResources.proxyPort, proxyPort)
)
)
.limit(1);
if (
existingResource &&
existingResource.siteResourceId !== siteResourceId
) {
return next(
createHttpError(
HttpCode.CONFLICT,
"A resource with the same protocol and proxy port already exists"
)
);
}
// Update the site resource
const [updatedSiteResource] = await db
.update(siteResources)
.set(updateData)
.where(
and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
.returning();
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found"));
}
await addTargets(
newt.newtId,
updatedSiteResource.destinationIp,
updatedSiteResource.destinationPort,
updatedSiteResource.protocol
);
logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}`
);
return response(res, {
data: updatedSiteResource,
success: true,
error: false,
message: "Site resource updated successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error updating site resource:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to update site resource"
)
);
}
}

View File

@@ -26,6 +26,7 @@ const createTargetParamsSchema = z
const createTargetSchema = z
.object({
siteId: z.number().int().positive(),
ip: z.string().refine(isTargetValid),
method: z.string().optional().nullable(),
port: z.number().int().min(1).max(65535),
@@ -98,17 +99,41 @@ export async function createTarget(
);
}
const siteId = targetData.siteId;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, resource.siteId!))
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${resource.siteId} not found`
`Site with ID ${siteId} not found`
)
);
}
const existingTargets = await db
.select()
.from(targets)
.where(eq(targets.resourceId, resourceId));
const existingTarget = existingTargets.find(
(target) =>
target.ip === targetData.ip &&
target.port === targetData.port &&
target.method === targetData.method &&
target.siteId === targetData.siteId
);
if (existingTarget) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Target with IP ${targetData.ip}, port ${targetData.port}, method ${targetData.method} already exists for resource ID ${resourceId}`
)
);
}
@@ -173,7 +198,12 @@ export async function createTarget(
.where(eq(newts.siteId, site.siteId))
.limit(1);
addTargets(newt.newtId, newTarget, resource.protocol, resource.proxyPort);
await addTargets(
newt.newtId,
newTarget,
resource.protocol,
resource.proxyPort
);
}
}
}

View File

@@ -76,38 +76,38 @@ export async function deleteTarget(
);
}
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, resource.siteId!))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${resource.siteId} not found`
)
);
}
if (site.pubKey) {
if (site.type == "wireguard") {
await addPeer(site.exitNodeId!, {
publicKey: site.pubKey,
allowedIps: await getAllowedIps(site.siteId)
});
} else if (site.type == "newt") {
// get the newt on the site by querying the newt table for siteId
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
removeTargets(newt.newtId, [deletedTarget], resource.protocol, resource.proxyPort);
}
}
// const [site] = await db
// .select()
// .from(sites)
// .where(eq(sites.siteId, resource.siteId!))
// .limit(1);
//
// if (!site) {
// return next(
// createHttpError(
// HttpCode.NOT_FOUND,
// `Site with ID ${resource.siteId} not found`
// )
// );
// }
//
// if (site.pubKey) {
// if (site.type == "wireguard") {
// await addPeer(site.exitNodeId!, {
// publicKey: site.pubKey,
// allowedIps: await getAllowedIps(site.siteId)
// });
// } else if (site.type == "newt") {
// // get the newt on the site by querying the newt table for siteId
// const [newt] = await db
// .select()
// .from(newts)
// .where(eq(newts.siteId, site.siteId))
// .limit(1);
//
// removeTargets(newt.newtId, [deletedTarget], resource.protocol, resource.proxyPort);
// }
// }
return response(res, {
data: null,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { db, Target } from "@server/db";
import { targets } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
@@ -16,6 +16,8 @@ const getTargetSchema = z
})
.strict();
type GetTargetResponse = Target;
registry.registerPath({
method: "get",
path: "/target/{targetId}",
@@ -60,7 +62,7 @@ export async function getTarget(
);
}
return response(res, {
return response<GetTargetResponse>(res, {
data: target[0],
success: true,
error: false,

View File

@@ -8,29 +8,21 @@ export async function pickPort(siteId: number): Promise<{
internalPort: number;
targetIps: string[];
}> {
const resourcesRes = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
// TODO: is this all inefficient?
// Fetch targets for all resources of this site
const targetIps: string[] = [];
const targetInternalPorts: number[] = [];
await Promise.all(
resourcesRes.map(async (resource) => {
const targetsRes = await db
.select()
.from(targets)
.where(eq(targets.resourceId, resource.resourceId));
targetsRes.forEach((target) => {
targetIps.push(`${target.ip}/32`);
if (target.internalPort) {
targetInternalPorts.push(target.internalPort);
}
});
})
);
const targetsRes = await db
.select()
.from(targets)
.where(eq(targets.siteId, siteId));
targetsRes.forEach((target) => {
targetIps.push(`${target.ip}/32`);
if (target.internalPort) {
targetInternalPorts.push(target.internalPort);
}
});
let internalPort!: number;
// pick a port random port from 40000 to 65535 that is not in use
@@ -43,28 +35,20 @@ export async function pickPort(siteId: number): Promise<{
break;
}
}
currentBannedPorts.push(internalPort);
return { internalPort, targetIps };
}
export async function getAllowedIps(siteId: number) {
// TODO: is this all inefficient?
const resourcesRes = await db
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
// Fetch targets for all resources of this site
const targetIps = await Promise.all(
resourcesRes.map(async (resource) => {
const targetsRes = await db
.select()
.from(targets)
.where(eq(targets.resourceId, resource.resourceId));
return targetsRes.map((target) => `${target.ip}/32`);
})
);
const targetsRes = await db
.select()
.from(targets)
.where(eq(targets.siteId, siteId));
const targetIps = targetsRes.map((target) => `${target.ip}/32`);
return targetIps.flat();
}

View File

@@ -2,4 +2,4 @@ export * from "./getTarget";
export * from "./createTarget";
export * from "./deleteTarget";
export * from "./updateTarget";
export * from "./listTargets";
export * from "./listTargets";

View File

@@ -1,4 +1,4 @@
import { db } from "@server/db";
import { db, sites } from "@server/db";
import { targets } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
@@ -42,11 +42,12 @@ function queryTargets(resourceId: number) {
method: targets.method,
port: targets.port,
enabled: targets.enabled,
resourceId: targets.resourceId
// resourceName: resources.name,
resourceId: targets.resourceId,
siteId: targets.siteId,
siteType: sites.type
})
.from(targets)
// .leftJoin(resources, eq(targets.resourceId, resources.resourceId))
.leftJoin(sites, eq(sites.siteId, targets.siteId))
.where(eq(targets.resourceId, resourceId));
return baseQuery;

View File

@@ -22,6 +22,7 @@ const updateTargetParamsSchema = z
const updateTargetBodySchema = z
.object({
siteId: z.number().int().positive(),
ip: z.string().refine(isTargetValid),
method: z.string().min(1).max(10).optional().nullable(),
port: z.number().int().min(1).max(65535).optional(),
@@ -77,6 +78,7 @@ export async function updateTarget(
}
const { targetId } = parsedParams.data;
const { siteId } = parsedBody.data;
const [target] = await db
.select()
@@ -111,14 +113,42 @@ export async function updateTarget(
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, resource.siteId!))
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${resource.siteId} not found`
`Site with ID ${siteId} not found`
)
);
}
const targetData = {
...target,
...parsedBody.data
};
const existingTargets = await db
.select()
.from(targets)
.where(eq(targets.resourceId, target.resourceId));
const foundTarget = existingTargets.find(
(target) =>
target.targetId !== targetId && // Exclude the current target being updated
target.ip === targetData.ip &&
target.port === targetData.port &&
target.method === targetData.method &&
target.siteId === targetData.siteId
);
if (foundTarget) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Target with IP ${targetData.ip}, port ${targetData.port}, and method ${targetData.method} already exists on the same site.`
)
);
}
@@ -157,7 +187,12 @@ export async function updateTarget(
.where(eq(newts.siteId, site.siteId))
.limit(1);
addTargets(newt.newtId, [updatedTarget], resource.protocol, resource.proxyPort);
await addTargets(
newt.newtId,
[updatedTarget],
resource.protocol,
resource.proxyPort
);
}
}
return response(res, {

View File

@@ -1,11 +1,21 @@
import { Request, Response } from "express";
import { db, exitNodes } from "@server/db";
import { and, eq, inArray, or, isNull } from "drizzle-orm";
import { and, eq, inArray, or, isNull, ne } from "drizzle-orm";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import config from "@server/lib/config";
import { orgs, resources, sites, Target, targets } from "@server/db";
// Extended Target interface that includes site information
interface TargetWithSite extends Target {
site: {
siteId: number;
type: string;
subnet: string | null;
exitNodeId: number | null;
};
}
let currentExitNodeId: number;
export async function traefikConfigProvider(
@@ -44,8 +54,9 @@ export async function traefikConfigProvider(
}
}
// Get the site(s) on this exit node
const resourcesWithRelations = await tx
// Get resources with their targets and sites in a single optimized query
// Start from sites on this exit node, then join to targets and resources
const resourcesWithTargetsAndSites = await tx
.select({
// Resource fields
resourceId: resources.resourceId,
@@ -56,67 +67,82 @@ export async function traefikConfigProvider(
protocol: resources.protocol,
subdomain: resources.subdomain,
domainId: resources.domainId,
// Site fields
site: {
siteId: sites.siteId,
type: sites.type,
subnet: sites.subnet,
exitNodeId: sites.exitNodeId
},
enabled: resources.enabled,
stickySession: resources.stickySession,
tlsServerName: resources.tlsServerName,
setHostHeader: resources.setHostHeader,
enableProxy: resources.enableProxy
enableProxy: resources.enableProxy,
// Target fields
targetId: targets.targetId,
targetEnabled: targets.enabled,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
// Site fields
siteId: sites.siteId,
siteType: sites.type,
subnet: sites.subnet,
exitNodeId: sites.exitNodeId
})
.from(resources)
.innerJoin(sites, eq(sites.siteId, resources.siteId))
.from(sites)
.innerJoin(targets, eq(targets.siteId, sites.siteId))
.innerJoin(resources, eq(resources.resourceId, targets.resourceId))
.where(
or(
eq(sites.exitNodeId, currentExitNodeId),
isNull(sites.exitNodeId)
and(
eq(targets.enabled, true),
eq(resources.enabled, true),
or(
eq(sites.exitNodeId, currentExitNodeId),
isNull(sites.exitNodeId)
)
)
);
// Get all resource IDs from the first query
const resourceIds = resourcesWithRelations.map((r) => r.resourceId);
// Group by resource and include targets with their unique site data
const resourcesMap = new Map();
// Second query to get all enabled targets for these resources
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
resourcesWithTargetsAndSites.forEach((row) => {
const resourceId = row.resourceId;
// Create a map for fast target lookup by resourceId
const targetsMap = allTargets.reduce((map, target) => {
if (!map.has(target.resourceId)) {
map.set(target.resourceId, []);
if (!resourcesMap.has(resourceId)) {
resourcesMap.set(resourceId, {
resourceId: row.resourceId,
fullDomain: row.fullDomain,
ssl: row.ssl,
http: row.http,
proxyPort: row.proxyPort,
protocol: row.protocol,
subdomain: row.subdomain,
domainId: row.domainId,
enabled: row.enabled,
stickySession: row.stickySession,
tlsServerName: row.tlsServerName,
setHostHeader: row.setHostHeader,
enableProxy: row.enableProxy,
targets: []
});
}
map.get(target.resourceId).push(target);
return map;
}, new Map());
// Combine the data
return resourcesWithRelations.map((resource) => ({
...resource,
targets: targetsMap.get(resource.resourceId) || []
}));
// Add target with its associated site data
resourcesMap.get(resourceId).targets.push({
resourceId: row.resourceId,
targetId: row.targetId,
ip: row.ip,
method: row.method,
port: row.port,
internalPort: row.internalPort,
enabled: row.targetEnabled,
site: {
siteId: row.siteId,
type: row.siteType,
subnet: row.subnet,
exitNodeId: row.exitNodeId
}
});
});
return Array.from(resourcesMap.values());
});
if (!allResources.length) {
@@ -167,8 +193,7 @@ export async function traefikConfigProvider(
};
for (const resource of allResources) {
const targets = resource.targets as Target[];
const site = resource.site;
const targets = resource.targets as TargetWithSite[];
const routerName = `${resource.resourceId}-router`;
const serviceName = `${resource.resourceId}-service`;
@@ -272,13 +297,13 @@ export async function traefikConfigProvider(
config_output.http.services![serviceName] = {
loadBalancer: {
servers: targets
.filter((target: Target) => {
.filter((target: TargetWithSite) => {
if (!target.enabled) {
return false;
}
if (
site.type === "local" ||
site.type === "wireguard"
target.site.type === "local" ||
target.site.type === "wireguard"
) {
if (
!target.ip ||
@@ -287,27 +312,27 @@ export async function traefikConfigProvider(
) {
return false;
}
} else if (site.type === "newt") {
} else if (target.site.type === "newt") {
if (
!target.internalPort ||
!target.method ||
!site.subnet
!target.site.subnet
) {
return false;
}
}
return true;
})
.map((target: Target) => {
.map((target: TargetWithSite) => {
if (
site.type === "local" ||
site.type === "wireguard"
target.site.type === "local" ||
target.site.type === "wireguard"
) {
return {
url: `${target.method}://${target.ip}:${target.port}`
};
} else if (site.type === "newt") {
const ip = site.subnet!.split("/")[0];
} else if (target.site.type === "newt") {
const ip = target.site.subnet!.split("/")[0];
return {
url: `${target.method}://${ip}:${target.internalPort}`
};
@@ -393,34 +418,34 @@ export async function traefikConfigProvider(
config_output[protocol].services[serviceName] = {
loadBalancer: {
servers: targets
.filter((target: Target) => {
.filter((target: TargetWithSite) => {
if (!target.enabled) {
return false;
}
if (
site.type === "local" ||
site.type === "wireguard"
target.site.type === "local" ||
target.site.type === "wireguard"
) {
if (!target.ip || !target.port) {
return false;
}
} else if (site.type === "newt") {
if (!target.internalPort || !site.subnet) {
} else if (target.site.type === "newt") {
if (!target.internalPort || !target.site.subnet) {
return false;
}
}
return true;
})
.map((target: Target) => {
.map((target: TargetWithSite) => {
if (
site.type === "local" ||
site.type === "wireguard"
target.site.type === "local" ||
target.site.type === "wireguard"
) {
return {
address: `${target.ip}:${target.port}`
};
} else if (site.type === "newt") {
const ip = site.subnet!.split("/")[0];
} else if (target.site.type === "newt") {
const ip = target.site.subnet!.split("/")[0];
return {
address: `${ip}:${target.internalPort}`
};

View File

@@ -43,17 +43,17 @@ export async function addUserSite(
})
.returning();
const siteResources = await trx
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await trx.insert(userResources).values({
userId,
resourceId: resource.resourceId
});
}
// const siteResources = await trx
// .select()
// .from(resources)
// .where(eq(resources.siteId, siteId));
//
// for (const resource of siteResources) {
// await trx.insert(userResources).values({
// userId,
// resourceId: resource.resourceId
// });
// }
return response(res, {
data: newUserSite[0],

View File

@@ -71,22 +71,22 @@ export async function removeUserSite(
);
}
const siteResources = await trx
.select()
.from(resources)
.where(eq(resources.siteId, siteId));
for (const resource of siteResources) {
await trx
.delete(userResources)
.where(
and(
eq(userResources.userId, userId),
eq(userResources.resourceId, resource.resourceId)
)
)
.returning();
}
// const siteResources = await trx
// .select()
// .from(resources)
// .where(eq(resources.siteId, siteId));
//
// for (const resource of siteResources) {
// await trx
// .delete(userResources)
// .where(
// and(
// eq(userResources.userId, userId),
// eq(userResources.resourceId, resource.resourceId)
// )
// )
// .returning();
// }
});
return response(res, {

View File

@@ -23,7 +23,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
"olm/ping": handleOlmPingMessage,
"newt/socket/status": handleDockerStatusMessage,
"newt/socket/containers": handleDockerContainersMessage,
"newt/ping/request": handleNewtPingRequestMessage,
"newt/ping/request": handleNewtPingRequestMessage
};
startOfflineChecker(); // this is to handle the offline check for olms

View File

@@ -22,4 +22,4 @@ export default async function migration() {
console.log("Unable to add setupTokens table:", e);
throw e;
}
}
}

View File

@@ -32,4 +32,4 @@ export default async function migration() {
console.log("Unable to add setupTokens table:", e);
throw e;
}
}
}