mirror of
https://github.com/fosrl/pangolin.git
synced 2026-02-23 21:36:37 +00:00
Chungus 2.0
This commit is contained in:
97
server/private/routers/auth/getSessionTransferToken.ts
Normal file
97
server/private/routers/auth/getSessionTransferToken.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, sessionTransferToken } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import {
|
||||
generateSessionToken,
|
||||
SESSION_COOKIE_NAME
|
||||
} from "@server/auth/sessions/app";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { response } from "@server/lib/response";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const paramsSchema = z.object({}).strict();
|
||||
|
||||
export type GetSessionTransferTokenRenponse = {
|
||||
token: string;
|
||||
};
|
||||
|
||||
export async function getSessionTransferToken(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { user, session } = req;
|
||||
|
||||
if (!user || !session) {
|
||||
return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized"));
|
||||
}
|
||||
|
||||
const tokenRaw = generateSessionToken();
|
||||
const token = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(tokenRaw))
|
||||
);
|
||||
|
||||
const rawSessionId = req.cookies[SESSION_COOKIE_NAME];
|
||||
|
||||
if (!rawSessionId) {
|
||||
return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized"));
|
||||
}
|
||||
|
||||
const encryptedSession = encrypt(
|
||||
rawSessionId,
|
||||
config.getRawConfig().server.secret!
|
||||
);
|
||||
|
||||
await db.insert(sessionTransferToken).values({
|
||||
encryptedSession,
|
||||
token,
|
||||
sessionId: session.sessionId,
|
||||
expiresAt: Date.now() + 30 * 1000 // Token valid for 30 seconds
|
||||
});
|
||||
|
||||
return response<GetSessionTransferTokenRenponse>(res, {
|
||||
data: {
|
||||
token: tokenRaw
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Transfer token created successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
16
server/private/routers/auth/index.ts
Normal file
16
server/private/routers/auth/index.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./transferSession";
|
||||
export * from "./getSessionTransferToken";
|
||||
export * from "./quickStart";
|
||||
579
server/private/routers/auth/quickStart.ts
Normal file
579
server/private/routers/auth/quickStart.ts
Normal file
@@ -0,0 +1,579 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import {
|
||||
account,
|
||||
db,
|
||||
domainNamespaces,
|
||||
domains,
|
||||
exitNodes,
|
||||
newts,
|
||||
newtSessions,
|
||||
orgs,
|
||||
passwordResetTokens,
|
||||
Resource,
|
||||
resourcePassword,
|
||||
resourcePincode,
|
||||
resources,
|
||||
resourceWhitelist,
|
||||
roleResources,
|
||||
roles,
|
||||
roleSites,
|
||||
sites,
|
||||
targetHealthCheck,
|
||||
targets,
|
||||
userResources,
|
||||
userSites
|
||||
} from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { users } from "@server/db";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import { eq, and, sql } from "drizzle-orm";
|
||||
import moment from "moment";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import config from "@server/lib/config";
|
||||
import logger from "@server/logger";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
|
||||
import { sendEmail } from "@server/emails";
|
||||
import WelcomeQuickStart from "@server/emails/templates/WelcomeQuickStart";
|
||||
import { alphabet, generateRandomString } from "oslo/crypto";
|
||||
import { createDate, TimeSpan } from "oslo";
|
||||
import { getUniqueResourceName, getUniqueSiteName } from "@server/db/names";
|
||||
import { pickPort } from "@server/routers/target/helpers";
|
||||
import { addTargets } from "@server/routers/newt/targets";
|
||||
import { isTargetValid } from "@server/lib/validators";
|
||||
import { listExitNodes } from "#private/lib/exitNodes";
|
||||
|
||||
const bodySchema = z.object({
|
||||
email: z.string().toLowerCase().email(),
|
||||
ip: z.string().refine(isTargetValid),
|
||||
method: z.enum(["http", "https"]),
|
||||
port: z.number().int().min(1).max(65535),
|
||||
pincode: z
|
||||
.string()
|
||||
.regex(/^\d{6}$/)
|
||||
.optional(),
|
||||
password: z.string().min(4).max(100).optional(),
|
||||
enableWhitelist: z.boolean().optional().default(true),
|
||||
animalId: z.string() // This is actually the secret key for the backend
|
||||
});
|
||||
|
||||
export type QuickStartBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type QuickStartResponse = {
|
||||
newtId: string;
|
||||
newtSecret: string;
|
||||
resourceUrl: string;
|
||||
completeSignUpLink: string;
|
||||
};
|
||||
|
||||
const DEMO_UBO_KEY = "b460293f-347c-4b30-837d-4e06a04d5a22";
|
||||
|
||||
export async function quickStart(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const {
|
||||
email,
|
||||
ip,
|
||||
method,
|
||||
port,
|
||||
pincode,
|
||||
password,
|
||||
enableWhitelist,
|
||||
animalId
|
||||
} = parsedBody.data;
|
||||
|
||||
try {
|
||||
const tokenValidation = validateTokenOnApi(animalId);
|
||||
|
||||
if (!tokenValidation.isValid) {
|
||||
logger.warn(
|
||||
`Quick start failed for ${email} token ${animalId}: ${tokenValidation.message}`
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid or expired token"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (animalId === DEMO_UBO_KEY) {
|
||||
if (email !== "mehrdad@getubo.com") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid email for demo Ubo key"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(
|
||||
and(
|
||||
eq(users.email, email),
|
||||
eq(users.type, UserType.Internal)
|
||||
)
|
||||
);
|
||||
|
||||
if (existing) {
|
||||
// delete the user if it already exists
|
||||
await db.delete(users).where(eq(users.userId, existing.userId));
|
||||
const orgId = `org_${existing.userId}`;
|
||||
await db.delete(orgs).where(eq(orgs.orgId, orgId));
|
||||
}
|
||||
}
|
||||
|
||||
const tempPassword = generateId(15);
|
||||
const passwordHash = await hashPassword(tempPassword);
|
||||
const userId = generateId(15);
|
||||
|
||||
// TODO: see if that user already exists?
|
||||
|
||||
// Create the sandbox user
|
||||
const existing = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(
|
||||
and(eq(users.email, email), eq(users.type, UserType.Internal))
|
||||
);
|
||||
|
||||
if (existing && existing.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A user with that email address already exists"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let newtId: string;
|
||||
let secret: string;
|
||||
let fullDomain: string;
|
||||
let resource: Resource;
|
||||
let completeSignUpLink: string;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx.insert(users).values({
|
||||
userId: userId,
|
||||
type: UserType.Internal,
|
||||
username: email,
|
||||
email: email,
|
||||
passwordHash,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
// create user"s account
|
||||
await trx.insert(account).values({
|
||||
userId
|
||||
});
|
||||
});
|
||||
|
||||
const { success, error, org } = await createUserAccountOrg(
|
||||
userId,
|
||||
email
|
||||
);
|
||||
if (!success) {
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
throw new Error("Failed to create user account and organization");
|
||||
}
|
||||
if (!org) {
|
||||
throw new Error("Failed to create user account and organization");
|
||||
}
|
||||
|
||||
const orgId = org.orgId;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const token = generateRandomString(
|
||||
8,
|
||||
alphabet("0-9", "A-Z", "a-z")
|
||||
);
|
||||
|
||||
await trx
|
||||
.delete(passwordResetTokens)
|
||||
.where(eq(passwordResetTokens.userId, userId));
|
||||
|
||||
const tokenHash = await hashPassword(token);
|
||||
|
||||
await trx.insert(passwordResetTokens).values({
|
||||
userId: userId,
|
||||
email: email,
|
||||
tokenHash,
|
||||
expiresAt: createDate(new TimeSpan(7, "d")).getTime()
|
||||
});
|
||||
|
||||
// // Create the sandbox newt
|
||||
// const newClientAddress = await getNextAvailableClientSubnet(orgId);
|
||||
// if (!newClientAddress) {
|
||||
// throw new Error("No available subnet found");
|
||||
// }
|
||||
|
||||
// const clientAddress = newClientAddress.split("/")[0];
|
||||
|
||||
newtId = generateId(15);
|
||||
secret = generateId(48);
|
||||
|
||||
// Create the sandbox site
|
||||
const siteNiceId = await getUniqueSiteName(orgId);
|
||||
const siteName = `First Site`;
|
||||
|
||||
// pick a random exit node
|
||||
const exitNodesList = await listExitNodes(orgId);
|
||||
|
||||
// select a random exit node
|
||||
const randomExitNode =
|
||||
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
|
||||
|
||||
if (!randomExitNode) {
|
||||
throw new Error("No exit nodes available");
|
||||
}
|
||||
|
||||
const [newSite] = await trx
|
||||
.insert(sites)
|
||||
.values({
|
||||
orgId,
|
||||
exitNodeId: randomExitNode.exitNodeId,
|
||||
name: siteName,
|
||||
niceId: siteNiceId,
|
||||
// address: clientAddress,
|
||||
type: "newt",
|
||||
dockerSocketEnabled: true
|
||||
})
|
||||
.returning();
|
||||
|
||||
const siteId = newSite.siteId;
|
||||
|
||||
const adminRole = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (adminRole.length === 0) {
|
||||
throw new Error("Admin role not found");
|
||||
}
|
||||
|
||||
await trx.insert(roleSites).values({
|
||||
roleId: adminRole[0].roleId,
|
||||
siteId: newSite.siteId
|
||||
});
|
||||
|
||||
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
|
||||
// make sure the user can access the site
|
||||
await trx.insert(userSites).values({
|
||||
userId: req.user?.userId!,
|
||||
siteId: newSite.siteId
|
||||
});
|
||||
}
|
||||
|
||||
// add the peer to the exit node
|
||||
const secretHash = await hashPassword(secret!);
|
||||
|
||||
await trx.insert(newts).values({
|
||||
newtId: newtId!,
|
||||
secretHash,
|
||||
siteId: newSite.siteId,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
const [randomNamespace] = await trx
|
||||
.select()
|
||||
.from(domainNamespaces)
|
||||
.orderBy(sql`RANDOM()`)
|
||||
.limit(1);
|
||||
|
||||
if (!randomNamespace) {
|
||||
throw new Error("No domain namespace available");
|
||||
}
|
||||
|
||||
const [randomNamespaceDomain] = await trx
|
||||
.select()
|
||||
.from(domains)
|
||||
.where(eq(domains.domainId, randomNamespace.domainId))
|
||||
.limit(1);
|
||||
|
||||
if (!randomNamespaceDomain) {
|
||||
throw new Error("No domain found for the namespace");
|
||||
}
|
||||
|
||||
const resourceNiceId = await getUniqueResourceName(orgId);
|
||||
|
||||
// Create sandbox resource
|
||||
const subdomain = `${resourceNiceId}-${generateId(5)}`;
|
||||
fullDomain = `${subdomain}.${randomNamespaceDomain.baseDomain}`;
|
||||
|
||||
const resourceName = `First Resource`;
|
||||
|
||||
const newResource = await trx
|
||||
.insert(resources)
|
||||
.values({
|
||||
niceId: resourceNiceId,
|
||||
fullDomain,
|
||||
domainId: randomNamespaceDomain.domainId,
|
||||
orgId,
|
||||
name: resourceName,
|
||||
subdomain,
|
||||
http: true,
|
||||
protocol: "tcp",
|
||||
ssl: true,
|
||||
sso: false,
|
||||
emailWhitelistEnabled: enableWhitelist
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx.insert(roleResources).values({
|
||||
roleId: adminRole[0].roleId,
|
||||
resourceId: newResource[0].resourceId
|
||||
});
|
||||
|
||||
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
|
||||
// make sure the user can access the resource
|
||||
await trx.insert(userResources).values({
|
||||
userId: req.user?.userId!,
|
||||
resourceId: newResource[0].resourceId
|
||||
});
|
||||
}
|
||||
|
||||
resource = newResource[0];
|
||||
|
||||
// Create the sandbox target
|
||||
const { internalPort, targetIps } = await pickPort(siteId!, trx);
|
||||
|
||||
if (!internalPort) {
|
||||
throw new Error("No available internal port");
|
||||
}
|
||||
|
||||
const newTarget = await trx
|
||||
.insert(targets)
|
||||
.values({
|
||||
resourceId: resource.resourceId,
|
||||
siteId: siteId!,
|
||||
internalPort,
|
||||
ip,
|
||||
method,
|
||||
port,
|
||||
enabled: true
|
||||
})
|
||||
.returning();
|
||||
|
||||
const newHealthcheck = await trx
|
||||
.insert(targetHealthCheck)
|
||||
.values({
|
||||
targetId: newTarget[0].targetId,
|
||||
hcEnabled: false
|
||||
}).returning();
|
||||
|
||||
// add the new target to the targetIps array
|
||||
targetIps.push(`${ip}/32`);
|
||||
|
||||
const [newt] = await trx
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId!))
|
||||
.limit(1);
|
||||
|
||||
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
|
||||
|
||||
// Set resource pincode if provided
|
||||
if (pincode) {
|
||||
await trx
|
||||
.delete(resourcePincode)
|
||||
.where(
|
||||
eq(resourcePincode.resourceId, resource!.resourceId)
|
||||
);
|
||||
|
||||
const pincodeHash = await hashPassword(pincode);
|
||||
|
||||
await trx.insert(resourcePincode).values({
|
||||
resourceId: resource!.resourceId,
|
||||
pincodeHash,
|
||||
digitLength: 6
|
||||
});
|
||||
}
|
||||
|
||||
// Set resource password if provided
|
||||
if (password) {
|
||||
await trx
|
||||
.delete(resourcePassword)
|
||||
.where(
|
||||
eq(resourcePassword.resourceId, resource!.resourceId)
|
||||
);
|
||||
|
||||
const passwordHash = await hashPassword(password);
|
||||
|
||||
await trx.insert(resourcePassword).values({
|
||||
resourceId: resource!.resourceId,
|
||||
passwordHash
|
||||
});
|
||||
}
|
||||
|
||||
// Set resource OTP if whitelist is enabled
|
||||
if (enableWhitelist) {
|
||||
await trx.insert(resourceWhitelist).values({
|
||||
email,
|
||||
resourceId: resource!.resourceId
|
||||
});
|
||||
}
|
||||
|
||||
completeSignUpLink = `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}&token=${token}`;
|
||||
|
||||
// Store token for email outside transaction
|
||||
await sendEmail(
|
||||
WelcomeQuickStart({
|
||||
username: email,
|
||||
link: completeSignUpLink,
|
||||
fallbackLink: `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}`,
|
||||
resourceMethod: method,
|
||||
resourceHostname: ip,
|
||||
resourcePort: port,
|
||||
resourceUrl: `https://${fullDomain}`,
|
||||
cliCommand: `newt --id ${newtId} --secret ${secret}`
|
||||
}),
|
||||
{
|
||||
to: email,
|
||||
from: config.getNoReplyEmail(),
|
||||
subject: `Access your Pangolin dashboard and resources`
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
return response<QuickStartResponse>(res, {
|
||||
data: {
|
||||
newtId: newtId!,
|
||||
newtSecret: secret!,
|
||||
resourceUrl: `https://${fullDomain!}`,
|
||||
completeSignUpLink: completeSignUpLink!
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Quick start completed successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Account already exists with that email. Email: ${email}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A user with that email address already exists"
|
||||
)
|
||||
);
|
||||
} else {
|
||||
logger.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to do quick start"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const BACKEND_SECRET_KEY = "4f9b6000-5d1a-11f0-9de7-ff2cc032f501";
|
||||
|
||||
/**
|
||||
* Validates a token received from the frontend.
|
||||
* @param {string} token The validation token from the request.
|
||||
* @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid.
|
||||
*/
|
||||
const validateTokenOnApi = (
|
||||
token: string
|
||||
): { isValid: boolean; message: string } => {
|
||||
if (token === DEMO_UBO_KEY) {
|
||||
// Special case for demo UBO key
|
||||
return { isValid: true, message: "Demo UBO key is valid." };
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
return { isValid: false, message: "Error: No token provided." };
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. Decode the base64 string
|
||||
const decodedB64 = atob(token);
|
||||
|
||||
// 2. Reverse the character code manipulation
|
||||
const deobfuscated = decodedB64
|
||||
.split("")
|
||||
.map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift
|
||||
.join("");
|
||||
|
||||
// 3. Split the data to get the original secret and timestamp
|
||||
const parts = deobfuscated.split("|");
|
||||
if (parts.length !== 2) {
|
||||
throw new Error("Invalid token format.");
|
||||
}
|
||||
const receivedKey = parts[0];
|
||||
const tokenTimestamp = parseInt(parts[1], 10);
|
||||
|
||||
// 4. Check if the secret key matches
|
||||
if (receivedKey !== BACKEND_SECRET_KEY) {
|
||||
return { isValid: false, message: "Invalid token: Key mismatch." };
|
||||
}
|
||||
|
||||
// 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks
|
||||
const now = Date.now();
|
||||
const timeDifference = now - tokenTimestamp;
|
||||
|
||||
if (timeDifference > 30000) {
|
||||
// 30 seconds
|
||||
return { isValid: false, message: "Invalid token: Expired." };
|
||||
}
|
||||
|
||||
if (timeDifference < 0) {
|
||||
// Timestamp is in the future
|
||||
return {
|
||||
isValid: false,
|
||||
message: "Invalid token: Timestamp is in the future."
|
||||
};
|
||||
}
|
||||
|
||||
// If all checks pass, the token is valid
|
||||
return { isValid: true, message: "Token is valid!" };
|
||||
} catch (error) {
|
||||
// This will catch errors from atob (if not valid base64) or other issues.
|
||||
return {
|
||||
isValid: false,
|
||||
message: `Error: ${(error as Error).message}`
|
||||
};
|
||||
}
|
||||
};
|
||||
128
server/private/routers/auth/transferSession.ts
Normal file
128
server/private/routers/auth/transferSession.ts
Normal file
@@ -0,0 +1,128 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { sessions, sessionTransferToken } from "@server/db";
|
||||
import { db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { response } from "@server/lib/response";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { serializeSessionCookie } from "@server/auth/sessions/app";
|
||||
import { decrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const bodySchema = z.object({
|
||||
token: z.string()
|
||||
});
|
||||
|
||||
export type TransferSessionBodySchema = z.infer<typeof bodySchema>;
|
||||
|
||||
export type TransferSessionResponse = {
|
||||
valid: boolean;
|
||||
cookie?: string;
|
||||
};
|
||||
|
||||
export async function transferSession(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const { token } = parsedBody.data;
|
||||
|
||||
const tokenRaw = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(token))
|
||||
);
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(sessionTransferToken)
|
||||
.where(eq(sessionTransferToken.token, tokenRaw))
|
||||
.innerJoin(
|
||||
sessions,
|
||||
eq(sessions.sessionId, sessionTransferToken.sessionId)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!existing) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Invalid transfer token")
|
||||
);
|
||||
}
|
||||
|
||||
const transferToken = existing.sessionTransferToken;
|
||||
const session = existing.session;
|
||||
|
||||
if (!transferToken) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Invalid transfer token")
|
||||
);
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(sessionTransferToken)
|
||||
.where(eq(sessionTransferToken.token, tokenRaw));
|
||||
|
||||
if (Date.now() > transferToken.expiresAt) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Transfer token expired")
|
||||
);
|
||||
}
|
||||
|
||||
const rawSession = decrypt(
|
||||
transferToken.encryptedSession,
|
||||
config.getRawConfig().server.secret!
|
||||
);
|
||||
|
||||
const isSecure = req.protocol === "https";
|
||||
const cookie = serializeSessionCookie(
|
||||
rawSession,
|
||||
isSecure,
|
||||
new Date(session.expiresAt)
|
||||
);
|
||||
res.appendHeader("Set-Cookie", cookie);
|
||||
|
||||
return response<TransferSessionResponse>(res, {
|
||||
data: { valid: true, cookie },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Session exchanged successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to exchange session"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
101
server/private/routers/billing/createCheckoutSession.ts
Normal file
101
server/private/routers/billing/createCheckoutSession.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { customers, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
|
||||
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
|
||||
|
||||
const createCheckoutSessionSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function createCheckoutSession(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = createCheckoutSessionSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
// check if we already have a customer for this org
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
// If we don't have a customer, create one
|
||||
if (!customer) {
|
||||
// error
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No customer found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const standardTierPrice = getTierPriceSet()[TierId.STANDARD];
|
||||
|
||||
const session = await stripe!.checkout.sessions.create({
|
||||
client_reference_id: orgId, // So we can look it up the org later on the webhook
|
||||
billing_address_collection: "required",
|
||||
line_items: [
|
||||
{
|
||||
price: standardTierPrice, // Use the standard tier
|
||||
quantity: 1
|
||||
},
|
||||
...getLineItems(getStandardFeaturePriceSet())
|
||||
], // Start with the standard feature set that matches the free limits
|
||||
customer: customer.customerId,
|
||||
mode: "subscription",
|
||||
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?success=true&session_id={CHECKOUT_SESSION_ID}`,
|
||||
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?canceled=true`
|
||||
});
|
||||
|
||||
return response<string>(res, {
|
||||
data: session.url,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Organization created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
89
server/private/routers/billing/createPortalSession.ts
Normal file
89
server/private/routers/billing/createPortalSession.ts
Normal file
@@ -0,0 +1,89 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { account, customers, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import stripe from "#private/lib/stripe";
|
||||
|
||||
const createPortalSessionSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function createPortalSession(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = createPortalSessionSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
// check if we already have a customer for this org
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
let customerId: string;
|
||||
// If we don't have a customer, create one
|
||||
if (!customer) {
|
||||
// error
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No customer found for this organization"
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// If we have a customer, use the existing customer ID
|
||||
customerId = customer.customerId;
|
||||
}
|
||||
const portalSession = await stripe!.billingPortal.sessions.create({
|
||||
customer: customerId,
|
||||
return_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing`
|
||||
});
|
||||
|
||||
return response<string>(res, {
|
||||
data: portalSession.url,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Organization created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
157
server/private/routers/billing/getOrgSubscription.ts
Normal file
157
server/private/routers/billing/getOrgSubscription.ts
Normal file
@@ -0,0 +1,157 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { Org, orgs } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromZodError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
// Import tables for billing
|
||||
import {
|
||||
customers,
|
||||
subscriptions,
|
||||
subscriptionItems,
|
||||
Subscription,
|
||||
SubscriptionItem
|
||||
} from "@server/db";
|
||||
|
||||
const getOrgSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type GetOrgSubscriptionResponse = {
|
||||
subscription: Subscription | null;
|
||||
items: SubscriptionItem[];
|
||||
};
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/billing/subscription",
|
||||
description: "Get an organization",
|
||||
tags: [OpenAPITags.Org],
|
||||
request: {
|
||||
params: getOrgSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function getOrgSubscription(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = getOrgSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromZodError(parsedParams.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
let subscriptionData = null;
|
||||
let itemsData: SubscriptionItem[] = [];
|
||||
try {
|
||||
const { subscription, items } = await getOrgSubscriptionData(orgId);
|
||||
subscriptionData = subscription;
|
||||
itemsData = items;
|
||||
} catch (err) {
|
||||
if ((err as Error).message === "Not found") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Organization with ID ${orgId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
|
||||
return response<GetOrgSubscriptionResponse>(res, {
|
||||
data: {
|
||||
subscription: subscriptionData,
|
||||
items: itemsData
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Organization and subscription retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export async function getOrgSubscriptionData(
|
||||
orgId: string
|
||||
): Promise<{ subscription: Subscription | null; items: SubscriptionItem[] }> {
|
||||
const org = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (org.length === 0) {
|
||||
throw new Error(`Not found`);
|
||||
}
|
||||
|
||||
// Get customer for org
|
||||
const customer = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
let subscription = null;
|
||||
let items: SubscriptionItem[] = [];
|
||||
|
||||
if (customer.length > 0) {
|
||||
// Get subscription for customer
|
||||
const subs = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(eq(subscriptions.customerId, customer[0].customerId))
|
||||
.limit(1);
|
||||
|
||||
if (subs.length > 0) {
|
||||
subscription = subs[0];
|
||||
// Get subscription items
|
||||
items = await db
|
||||
.select()
|
||||
.from(subscriptionItems)
|
||||
.where(
|
||||
eq(
|
||||
subscriptionItems.subscriptionId,
|
||||
subscription.subscriptionId
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return { subscription, items };
|
||||
}
|
||||
129
server/private/routers/billing/getOrgUsage.ts
Normal file
129
server/private/routers/billing/getOrgUsage.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { orgs } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromZodError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { Limit, limits, Usage, usage } from "@server/db";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
|
||||
const getOrgSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type GetOrgUsageResponse = {
|
||||
usage: Usage[];
|
||||
limits: Limit[];
|
||||
};
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/billing/usage",
|
||||
description: "Get an organization's billing usage",
|
||||
tags: [OpenAPITags.Org],
|
||||
request: {
|
||||
params: getOrgSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function getOrgUsage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = getOrgSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromZodError(parsedParams.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const org = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (org.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Organization with ID ${orgId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Get usage for org
|
||||
const usageData = [];
|
||||
|
||||
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
|
||||
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
|
||||
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
|
||||
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
|
||||
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
|
||||
|
||||
if (siteUptime) {
|
||||
usageData.push(siteUptime);
|
||||
}
|
||||
if (users) {
|
||||
usageData.push(users);
|
||||
}
|
||||
if (egressData) {
|
||||
usageData.push(egressData);
|
||||
}
|
||||
if (domains) {
|
||||
usageData.push(domains);
|
||||
}
|
||||
if (remoteExitNodes) {
|
||||
usageData.push(remoteExitNodes);
|
||||
}
|
||||
|
||||
const orgLimits = await db.select()
|
||||
.from(limits)
|
||||
.where(eq(limits.orgId, orgId));
|
||||
|
||||
return response<GetOrgUsageResponse>(res, {
|
||||
data: {
|
||||
usage: usageData,
|
||||
limits: orgLimits
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Organization usage retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import { customers, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export async function handleCustomerCreated(
|
||||
customer: Stripe.Customer
|
||||
): Promise<void> {
|
||||
try {
|
||||
const [existingCustomer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.customerId, customer.id))
|
||||
.limit(1);
|
||||
|
||||
if (existingCustomer) {
|
||||
logger.info(`Customer with ID ${customer.id} already exists.`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!customer.metadata.orgId) {
|
||||
logger.error(
|
||||
`Customer with ID ${customer.id} does not have an orgId in metadata.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await db.insert(customers).values({
|
||||
customerId: customer.id,
|
||||
orgId: customer.metadata.orgId,
|
||||
email: customer.email || null,
|
||||
name: customer.name || null,
|
||||
createdAt: customer.created,
|
||||
updatedAt: customer.created
|
||||
});
|
||||
logger.info(`Customer with ID ${customer.id} created successfully.`);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling customer created event for ID ${customer.id}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import { customers, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export async function handleCustomerDeleted(
|
||||
customer: Stripe.Customer
|
||||
): Promise<void> {
|
||||
try {
|
||||
const [existingCustomer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.customerId, customer.id))
|
||||
.limit(1);
|
||||
|
||||
if (!existingCustomer) {
|
||||
logger.info(`Customer with ID ${customer.id} does not exist.`);
|
||||
return;
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(customers)
|
||||
.where(eq(customers.customerId, customer.id));
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling customer created event for ID ${customer.id}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import { customers, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export async function handleCustomerUpdated(
|
||||
customer: Stripe.Customer
|
||||
): Promise<void> {
|
||||
try {
|
||||
const [existingCustomer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.customerId, customer.id))
|
||||
.limit(1);
|
||||
|
||||
if (!existingCustomer) {
|
||||
logger.info(`Customer with ID ${customer.id} does not exist.`);
|
||||
return;
|
||||
}
|
||||
|
||||
const newCustomer = {
|
||||
customerId: customer.id,
|
||||
orgId: customer.metadata.orgId,
|
||||
email: customer.email || null,
|
||||
name: customer.name || null,
|
||||
updatedAt: Math.floor(Date.now() / 1000)
|
||||
};
|
||||
|
||||
// Update the existing customer record
|
||||
await db
|
||||
.update(customers)
|
||||
.set(newCustomer)
|
||||
.where(eq(customers.customerId, customer.id));
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling customer created event for ID ${customer.id}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import {
|
||||
customers,
|
||||
subscriptions,
|
||||
db,
|
||||
subscriptionItems,
|
||||
userOrgs,
|
||||
users
|
||||
} from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
|
||||
|
||||
export async function handleSubscriptionCreated(
|
||||
subscription: Stripe.Subscription
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Fetch the subscription from Stripe with expanded price.tiers
|
||||
const fullSubscription = await stripe!.subscriptions.retrieve(
|
||||
subscription.id,
|
||||
{
|
||||
expand: ["items.data.price.tiers"]
|
||||
}
|
||||
);
|
||||
|
||||
logger.info(JSON.stringify(fullSubscription, null, 2));
|
||||
// Check if subscription already exists
|
||||
const [existingSubscription] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(eq(subscriptions.subscriptionId, subscription.id))
|
||||
.limit(1);
|
||||
|
||||
if (existingSubscription) {
|
||||
logger.info(
|
||||
`Subscription with ID ${subscription.id} already exists.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const newSubscription = {
|
||||
subscriptionId: subscription.id,
|
||||
customerId: subscription.customer as string,
|
||||
status: subscription.status,
|
||||
canceledAt: subscription.canceled_at
|
||||
? subscription.canceled_at
|
||||
: null,
|
||||
createdAt: subscription.created
|
||||
};
|
||||
|
||||
await db.insert(subscriptions).values(newSubscription);
|
||||
logger.info(
|
||||
`Subscription with ID ${subscription.id} created successfully.`
|
||||
);
|
||||
|
||||
// Insert subscription items
|
||||
if (Array.isArray(fullSubscription.items?.data)) {
|
||||
const itemsToInsertPromises = fullSubscription.items.data.map(
|
||||
async (item) => {
|
||||
// try to get the product name from stripe and add it to the item
|
||||
let name = null;
|
||||
if (item.price.product) {
|
||||
const product = await stripe!.products.retrieve(
|
||||
item.price.product as string
|
||||
);
|
||||
name = product.name || null;
|
||||
}
|
||||
|
||||
return {
|
||||
subscriptionId: subscription.id,
|
||||
planId: item.plan.id,
|
||||
priceId: item.price.id,
|
||||
meterId: item.plan.meter,
|
||||
unitAmount: item.price.unit_amount || 0,
|
||||
currentPeriodStart: item.current_period_start,
|
||||
currentPeriodEnd: item.current_period_end,
|
||||
tiers: item.price.tiers
|
||||
? JSON.stringify(item.price.tiers)
|
||||
: null,
|
||||
interval: item.plan.interval,
|
||||
name
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
// wait for all items to be processed
|
||||
const itemsToInsert = await Promise.all(itemsToInsertPromises);
|
||||
|
||||
if (itemsToInsert.length > 0) {
|
||||
await db.insert(subscriptionItems).values(itemsToInsert);
|
||||
logger.info(
|
||||
`Inserted ${itemsToInsert.length} subscription items for subscription ${subscription.id}.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup customer to get orgId
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.customerId, subscription.customer as string))
|
||||
.limit(1);
|
||||
|
||||
if (!customer) {
|
||||
logger.error(
|
||||
`Customer with ID ${subscription.customer} not found for subscription ${subscription.id}.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
|
||||
|
||||
const [orgUserRes] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, customer.orgId),
|
||||
eq(userOrgs.isOwner, true)
|
||||
)
|
||||
)
|
||||
.innerJoin(users, eq(userOrgs.userId, users.userId));
|
||||
|
||||
if (orgUserRes) {
|
||||
const email = orgUserRes.user.email;
|
||||
|
||||
if (email) {
|
||||
moveEmailToAudience(email, AudienceIds.Subscribed);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling subscription created event for ID ${subscription.id}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
|
||||
|
||||
export async function handleSubscriptionDeleted(
|
||||
subscription: Stripe.Subscription
|
||||
): Promise<void> {
|
||||
try {
|
||||
const [existingSubscription] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(eq(subscriptions.subscriptionId, subscription.id))
|
||||
.limit(1);
|
||||
|
||||
if (!existingSubscription) {
|
||||
logger.info(
|
||||
`Subscription with ID ${subscription.id} does not exist.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(subscriptions)
|
||||
.where(eq(subscriptions.subscriptionId, subscription.id));
|
||||
|
||||
await db
|
||||
.delete(subscriptionItems)
|
||||
.where(eq(subscriptionItems.subscriptionId, subscription.id));
|
||||
|
||||
|
||||
// Lookup customer to get orgId
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.customerId, subscription.customer as string))
|
||||
.limit(1);
|
||||
|
||||
if (!customer) {
|
||||
logger.error(
|
||||
`Customer with ID ${subscription.customer} not found for subscription ${subscription.id}.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await handleSubscriptionLifesycle(
|
||||
customer.orgId,
|
||||
subscription.status
|
||||
);
|
||||
|
||||
const [orgUserRes] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, customer.orgId),
|
||||
eq(userOrgs.isOwner, true)
|
||||
)
|
||||
)
|
||||
.innerJoin(users, eq(userOrgs.userId, users.userId));
|
||||
|
||||
if (orgUserRes) {
|
||||
const email = orgUserRes.user.email;
|
||||
|
||||
if (email) {
|
||||
moveEmailToAudience(email, AudienceIds.Churned);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling subscription updated event for ID ${subscription.id}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import {
|
||||
subscriptions,
|
||||
db,
|
||||
subscriptionItems,
|
||||
usage,
|
||||
sites,
|
||||
customers,
|
||||
orgs
|
||||
} from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { getFeatureIdByMetricId } from "@server/lib/billing/features";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
|
||||
export async function handleSubscriptionUpdated(
|
||||
subscription: Stripe.Subscription,
|
||||
previousAttributes: Partial<Stripe.Subscription> | undefined
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Fetch the subscription from Stripe with expanded price.tiers
|
||||
const fullSubscription = await stripe!.subscriptions.retrieve(
|
||||
subscription.id,
|
||||
{
|
||||
expand: ["items.data.price.tiers"]
|
||||
}
|
||||
);
|
||||
|
||||
logger.info(JSON.stringify(fullSubscription, null, 2));
|
||||
|
||||
const [existingSubscription] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(eq(subscriptions.subscriptionId, subscription.id))
|
||||
.limit(1);
|
||||
|
||||
if (!existingSubscription) {
|
||||
logger.info(
|
||||
`Subscription with ID ${subscription.id} does not exist.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// get the customer
|
||||
const [existingCustomer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.customerId, subscription.customer as string))
|
||||
.limit(1);
|
||||
|
||||
await db
|
||||
.update(subscriptions)
|
||||
.set({
|
||||
status: subscription.status,
|
||||
canceledAt: subscription.canceled_at
|
||||
? subscription.canceled_at
|
||||
: null,
|
||||
updatedAt: Math.floor(Date.now() / 1000),
|
||||
billingCycleAnchor: subscription.billing_cycle_anchor
|
||||
})
|
||||
.where(eq(subscriptions.subscriptionId, subscription.id));
|
||||
|
||||
await handleSubscriptionLifesycle(
|
||||
existingCustomer.orgId,
|
||||
subscription.status
|
||||
);
|
||||
|
||||
// Upsert subscription items
|
||||
if (Array.isArray(fullSubscription.items?.data)) {
|
||||
const itemsToUpsert = fullSubscription.items.data.map((item) => ({
|
||||
subscriptionId: subscription.id,
|
||||
planId: item.plan.id,
|
||||
priceId: item.price.id,
|
||||
meterId: item.plan.meter,
|
||||
unitAmount: item.price.unit_amount || 0,
|
||||
currentPeriodStart: item.current_period_start,
|
||||
currentPeriodEnd: item.current_period_end,
|
||||
tiers: item.price.tiers
|
||||
? JSON.stringify(item.price.tiers)
|
||||
: null,
|
||||
interval: item.plan.interval
|
||||
}));
|
||||
if (itemsToUpsert.length > 0) {
|
||||
await db.transaction(async (trx) => {
|
||||
await trx
|
||||
.delete(subscriptionItems)
|
||||
.where(
|
||||
eq(
|
||||
subscriptionItems.subscriptionId,
|
||||
subscription.id
|
||||
)
|
||||
);
|
||||
|
||||
await trx.insert(subscriptionItems).values(itemsToUpsert);
|
||||
});
|
||||
logger.info(
|
||||
`Updated ${itemsToUpsert.length} subscription items for subscription ${subscription.id}.`
|
||||
);
|
||||
}
|
||||
|
||||
// --- Detect cycled items and update usage ---
|
||||
if (previousAttributes) {
|
||||
// Only proceed if latest_invoice changed (per Stripe docs)
|
||||
if ("latest_invoice" in previousAttributes) {
|
||||
// If items array present in previous_attributes, check each item
|
||||
if (Array.isArray(previousAttributes.items?.data)) {
|
||||
for (const item of subscription.items.data) {
|
||||
const prevItem = previousAttributes.items.data.find(
|
||||
(pi: any) => pi.id === item.id
|
||||
);
|
||||
if (
|
||||
prevItem &&
|
||||
prevItem.current_period_end &&
|
||||
item.current_period_start &&
|
||||
prevItem.current_period_end ===
|
||||
item.current_period_start &&
|
||||
item.current_period_start >
|
||||
prevItem.current_period_start
|
||||
) {
|
||||
logger.info(
|
||||
`Subscription item ${item.id} has cycled. Resetting usage.`
|
||||
);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This item has cycled
|
||||
const meterId = item.plan.meter;
|
||||
if (!meterId) {
|
||||
logger.warn(
|
||||
`No meterId found for subscription item ${item.id}. Skipping usage reset.`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const featureId = getFeatureIdByMetricId(meterId);
|
||||
if (!featureId) {
|
||||
logger.warn(
|
||||
`No featureId found for meterId ${meterId}. Skipping usage reset.`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const orgId = existingCustomer.orgId;
|
||||
|
||||
if (!orgId) {
|
||||
logger.warn(
|
||||
`No orgId found in subscription metadata for subscription ${subscription.id}. Skipping usage reset.`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const [usageRow] = await trx
|
||||
.select()
|
||||
.from(usage)
|
||||
.where(
|
||||
eq(
|
||||
usage.usageId,
|
||||
`${orgId}-${featureId}`
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (usageRow) {
|
||||
// get the next rollover date
|
||||
|
||||
const [org] = await trx
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
const lastRollover = usageRow.rolledOverAt
|
||||
? new Date(usageRow.rolledOverAt * 1000)
|
||||
: new Date();
|
||||
const anchorDate = org.createdAt
|
||||
? new Date(org.createdAt)
|
||||
: new Date();
|
||||
|
||||
const nextRollover =
|
||||
calculateNextRollOverDate(
|
||||
lastRollover,
|
||||
anchorDate
|
||||
);
|
||||
|
||||
await trx
|
||||
.update(usage)
|
||||
.set({
|
||||
previousValue: usageRow.latestValue,
|
||||
latestValue:
|
||||
usageRow.instantaneousValue ||
|
||||
0,
|
||||
updatedAt: Math.floor(
|
||||
Date.now() / 1000
|
||||
),
|
||||
rolledOverAt: Math.floor(
|
||||
Date.now() / 1000
|
||||
),
|
||||
nextRolloverAt: Math.floor(
|
||||
nextRollover.getTime() / 1000
|
||||
)
|
||||
})
|
||||
.where(
|
||||
eq(usage.usageId, usageRow.usageId)
|
||||
);
|
||||
logger.info(
|
||||
`Usage reset for org ${orgId}, meter ${featureId} on subscription item cycle.`
|
||||
);
|
||||
}
|
||||
|
||||
// Also reset the sites to 0
|
||||
await trx
|
||||
.update(sites)
|
||||
.set({
|
||||
megabytesIn: 0,
|
||||
megabytesOut: 0
|
||||
})
|
||||
.where(eq(sites.orgId, orgId));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- end usage update ---
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling subscription updated event for ID ${subscription.id}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the next billing date based on monthly billing cycle
|
||||
* Handles end-of-month scenarios as described in the requirements
|
||||
* Made public for testing
|
||||
*/
|
||||
function calculateNextRollOverDate(lastRollover: Date, anchorDate: Date): Date {
|
||||
const rolloverDate = new Date(lastRollover);
|
||||
const anchor = new Date(anchorDate);
|
||||
|
||||
// Get components from rollover date
|
||||
const rolloverYear = rolloverDate.getUTCFullYear();
|
||||
const rolloverMonth = rolloverDate.getUTCMonth();
|
||||
|
||||
// Calculate target month and year (next month)
|
||||
let targetMonth = rolloverMonth + 1;
|
||||
let targetYear = rolloverYear;
|
||||
|
||||
if (targetMonth > 11) {
|
||||
targetMonth = 0;
|
||||
targetYear++;
|
||||
}
|
||||
|
||||
// Get anchor day for billing
|
||||
const anchorDay = anchor.getUTCDate();
|
||||
|
||||
// Get the last day of the target month
|
||||
const lastDayOfMonth = new Date(
|
||||
Date.UTC(targetYear, targetMonth + 1, 0)
|
||||
).getUTCDate();
|
||||
|
||||
// Use the anchor day or the last day of the month, whichever is smaller
|
||||
const targetDay = Math.min(anchorDay, lastDayOfMonth);
|
||||
|
||||
// Create the next billing date using UTC
|
||||
const nextBilling = new Date(
|
||||
Date.UTC(
|
||||
targetYear,
|
||||
targetMonth,
|
||||
targetDay,
|
||||
anchor.getUTCHours(),
|
||||
anchor.getUTCMinutes(),
|
||||
anchor.getUTCSeconds(),
|
||||
anchor.getUTCMilliseconds()
|
||||
)
|
||||
);
|
||||
|
||||
return nextBilling;
|
||||
}
|
||||
18
server/private/routers/billing/index.ts
Normal file
18
server/private/routers/billing/index.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./createCheckoutSession";
|
||||
export * from "./createPortalSession";
|
||||
export * from "./getOrgSubscription";
|
||||
export * from "./getOrgUsage";
|
||||
export * from "./internalGetOrgTier";
|
||||
87
server/private/routers/billing/internalGetOrgTier.ts
Normal file
87
server/private/routers/billing/internalGetOrgTier.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromZodError } from "zod-validation-error";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
|
||||
const getOrgSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type GetOrgTierResponse = {
|
||||
tier: string | null;
|
||||
active: boolean;
|
||||
};
|
||||
|
||||
export async function getOrgTier(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = getOrgSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromZodError(parsedParams.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
let tierData = null;
|
||||
let activeData = false;
|
||||
|
||||
try {
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
tierData = tier;
|
||||
activeData = active;
|
||||
} catch (err) {
|
||||
if ((err as Error).message === "Not found") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Organization with ID ${orgId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
|
||||
return response<GetOrgTierResponse>(res, {
|
||||
data: {
|
||||
tier: tierData,
|
||||
active: activeData
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Organization and subscription retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
45
server/private/routers/billing/subscriptionLifecycle.ts
Normal file
45
server/private/routers/billing/subscriptionLifecycle.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
|
||||
switch (status) {
|
||||
case "active":
|
||||
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
break;
|
||||
case "canceled":
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
break;
|
||||
case "past_due":
|
||||
// Optionally handle past due status, e.g., notify customer
|
||||
break;
|
||||
case "unpaid":
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
break;
|
||||
case "incomplete":
|
||||
// Optionally handle incomplete status, e.g., notify customer
|
||||
break;
|
||||
case "incomplete_expired":
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
136
server/private/routers/billing/webhooks.ts
Normal file
136
server/private/routers/billing/webhooks.ts
Normal file
@@ -0,0 +1,136 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import stripe from "#private/lib/stripe";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import logger from "@server/logger";
|
||||
import createHttpError from "http-errors";
|
||||
import { response } from "@server/lib/response";
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import Stripe from "stripe";
|
||||
import { handleCustomerCreated } from "./hooks/handleCustomerCreated";
|
||||
import { handleSubscriptionCreated } from "./hooks/handleSubscriptionCreated";
|
||||
import { handleSubscriptionUpdated } from "./hooks/handleSubscriptionUpdated";
|
||||
import { handleCustomerUpdated } from "./hooks/handleCustomerUpdated";
|
||||
import { handleSubscriptionDeleted } from "./hooks/handleSubscriptionDeleted";
|
||||
import { handleCustomerDeleted } from "./hooks/handleCustomerDeleted";
|
||||
|
||||
export async function billingWebhookHandler(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
let event: Stripe.Event = req.body;
|
||||
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
|
||||
if (!endpointSecret) {
|
||||
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
|
||||
);
|
||||
}
|
||||
|
||||
// Only verify the event if you have an endpoint secret defined.
|
||||
// Otherwise use the basic event deserialized with JSON.parse
|
||||
if (endpointSecret) {
|
||||
// Get the signature sent by Stripe
|
||||
const signature = req.headers["stripe-signature"];
|
||||
|
||||
if (!signature) {
|
||||
logger.info("No stripe signature found in headers.");
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
event = stripe!.webhooks.constructEvent(
|
||||
req.body,
|
||||
signature,
|
||||
endpointSecret
|
||||
);
|
||||
} catch (err) {
|
||||
logger.error(`Webhook signature verification failed.`, err);
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
|
||||
);
|
||||
}
|
||||
}
|
||||
let subscription;
|
||||
let previousAttributes;
|
||||
// Handle the event
|
||||
switch (event.type) {
|
||||
case "customer.created":
|
||||
const customer = event.data.object;
|
||||
logger.info("Customer created: ", customer);
|
||||
handleCustomerCreated(customer);
|
||||
break;
|
||||
case "customer.updated":
|
||||
const customerUpdated = event.data.object;
|
||||
logger.info("Customer updated: ", customerUpdated);
|
||||
handleCustomerUpdated(customerUpdated);
|
||||
break;
|
||||
case "customer.deleted":
|
||||
const customerDeleted = event.data.object;
|
||||
logger.info("Customer deleted: ", customerDeleted);
|
||||
handleCustomerDeleted(customerDeleted);
|
||||
break;
|
||||
case "customer.subscription.paused":
|
||||
subscription = event.data.object;
|
||||
previousAttributes = event.data.previous_attributes;
|
||||
handleSubscriptionUpdated(subscription, previousAttributes);
|
||||
break;
|
||||
case "customer.subscription.resumed":
|
||||
subscription = event.data.object;
|
||||
previousAttributes = event.data.previous_attributes;
|
||||
handleSubscriptionUpdated(subscription, previousAttributes);
|
||||
break;
|
||||
case "customer.subscription.deleted":
|
||||
subscription = event.data.object;
|
||||
handleSubscriptionDeleted(subscription);
|
||||
break;
|
||||
case "customer.subscription.created":
|
||||
subscription = event.data.object;
|
||||
handleSubscriptionCreated(subscription);
|
||||
break;
|
||||
case "customer.subscription.updated":
|
||||
subscription = event.data.object;
|
||||
previousAttributes = event.data.previous_attributes;
|
||||
handleSubscriptionUpdated(subscription, previousAttributes);
|
||||
break;
|
||||
case "customer.subscription.trial_will_end":
|
||||
subscription = event.data.object;
|
||||
// Then define and call a method to handle the subscription trial ending.
|
||||
// handleSubscriptionTrialEnding(subscription);
|
||||
break;
|
||||
case "entitlements.active_entitlement_summary.updated":
|
||||
subscription = event.data.object;
|
||||
logger.info(
|
||||
`Active entitlement summary updated for ${subscription}.`
|
||||
);
|
||||
// Then define and call a method to handle active entitlement summary updated
|
||||
// handleEntitlementUpdated(subscription);
|
||||
break;
|
||||
default:
|
||||
// Unexpected event type
|
||||
logger.info(`Unhandled event type ${event.type}.`);
|
||||
}
|
||||
// Return a 200 response to acknowledge receipt of the event
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Webhook event processed successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
}
|
||||
85
server/private/routers/certificates/createCertificate.ts
Normal file
85
server/private/routers/certificates/createCertificate.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Certificate, certificates, db, domains } from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import { Transaction } from "@server/db";
|
||||
import { eq, or, and, like } from "drizzle-orm";
|
||||
import { build } from "@server/build";
|
||||
|
||||
/**
|
||||
* Checks if a certificate exists for the given domain.
|
||||
* If not, creates a new certificate in 'pending' state.
|
||||
* Wildcard certs cover subdomains.
|
||||
*/
|
||||
export async function createCertificate(domainId: string, domain: string, trx: Transaction | typeof db) {
|
||||
if (build !== "saas") {
|
||||
return;
|
||||
}
|
||||
|
||||
const [domainRecord] = await trx
|
||||
.select()
|
||||
.from(domains)
|
||||
.where(eq(domains.domainId, domainId))
|
||||
.limit(1);
|
||||
|
||||
if (!domainRecord) {
|
||||
throw new Error(`Domain with ID ${domainId} not found`);
|
||||
}
|
||||
|
||||
let existing: Certificate[] = [];
|
||||
if (domainRecord.type == "ns") {
|
||||
const domainLevelDown = domain.split('.').slice(1).join('.');
|
||||
existing = await trx
|
||||
.select()
|
||||
.from(certificates)
|
||||
.where(
|
||||
and(
|
||||
eq(certificates.domainId, domainId),
|
||||
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
|
||||
or(
|
||||
eq(certificates.domain, domain),
|
||||
eq(certificates.domain, domainLevelDown),
|
||||
)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// For non-NS domains, we only match exact domain names
|
||||
existing = await trx
|
||||
.select()
|
||||
.from(certificates)
|
||||
.where(
|
||||
and(
|
||||
eq(certificates.domainId, domainId),
|
||||
eq(certificates.domain, domain) // exact match for non-NS domains
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existing.length > 0) {
|
||||
logger.info(
|
||||
`Certificate already exists for domain ${domain}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// No cert found, create a new one in pending state
|
||||
await trx.insert(certificates).values({
|
||||
domain,
|
||||
domainId,
|
||||
wildcard: domainRecord.type == "ns", // we can only create wildcard certs for NS domains
|
||||
status: "pending",
|
||||
updatedAt: Math.floor(Date.now() / 1000),
|
||||
createdAt: Math.floor(Date.now() / 1000)
|
||||
});
|
||||
}
|
||||
167
server/private/routers/certificates/getCertificate.ts
Normal file
167
server/private/routers/certificates/getCertificate.ts
Normal file
@@ -0,0 +1,167 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { certificates, db, domains } from "@server/db";
|
||||
import { eq, and, or, like } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { registry } from "@server/openApi";
|
||||
|
||||
const getCertificateSchema = z
|
||||
.object({
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
async function query(domainId: string, domain: string) {
|
||||
const [domainRecord] = await db
|
||||
.select()
|
||||
.from(domains)
|
||||
.where(eq(domains.domainId, domainId))
|
||||
.limit(1);
|
||||
|
||||
if (!domainRecord) {
|
||||
throw new Error(`Domain with ID ${domainId} not found`);
|
||||
}
|
||||
|
||||
let existing: any[] = [];
|
||||
if (domainRecord.type == "ns") {
|
||||
const domainLevelDown = domain.split('.').slice(1).join('.');
|
||||
|
||||
existing = await db
|
||||
.select({
|
||||
certId: certificates.certId,
|
||||
domain: certificates.domain,
|
||||
wildcard: certificates.wildcard,
|
||||
status: certificates.status,
|
||||
expiresAt: certificates.expiresAt,
|
||||
lastRenewalAttempt: certificates.lastRenewalAttempt,
|
||||
createdAt: certificates.createdAt,
|
||||
updatedAt: certificates.updatedAt,
|
||||
errorMessage: certificates.errorMessage,
|
||||
renewalCount: certificates.renewalCount
|
||||
})
|
||||
.from(certificates)
|
||||
.where(
|
||||
and(
|
||||
eq(certificates.domainId, domainId),
|
||||
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
|
||||
or(
|
||||
eq(certificates.domain, domain),
|
||||
eq(certificates.domain, domainLevelDown),
|
||||
)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// For non-NS domains, we only match exact domain names
|
||||
existing = await db
|
||||
.select({
|
||||
certId: certificates.certId,
|
||||
domain: certificates.domain,
|
||||
wildcard: certificates.wildcard,
|
||||
status: certificates.status,
|
||||
expiresAt: certificates.expiresAt,
|
||||
lastRenewalAttempt: certificates.lastRenewalAttempt,
|
||||
createdAt: certificates.createdAt,
|
||||
updatedAt: certificates.updatedAt,
|
||||
errorMessage: certificates.errorMessage,
|
||||
renewalCount: certificates.renewalCount
|
||||
})
|
||||
.from(certificates)
|
||||
.where(
|
||||
and(
|
||||
eq(certificates.domainId, domainId),
|
||||
eq(certificates.domain, domain) // exact match for non-NS domains
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return existing.length > 0 ? existing[0] : null;
|
||||
}
|
||||
|
||||
export type GetCertificateResponse = {
|
||||
certId: number;
|
||||
domain: string;
|
||||
domainId: string;
|
||||
wildcard: boolean;
|
||||
status: string; // pending, requested, valid, expired, failed
|
||||
expiresAt: string | null;
|
||||
lastRenewalAttempt: Date | null;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
errorMessage?: string | null;
|
||||
renewalCount: number;
|
||||
}
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/certificate/{domainId}/{domain}",
|
||||
description: "Get a certificate by domain.",
|
||||
tags: ["Certificate"],
|
||||
request: {
|
||||
params: z.object({
|
||||
domainId: z
|
||||
.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
})
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function getCertificate(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = getCertificateSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { domainId, domain } = parsedParams.data;
|
||||
|
||||
const cert = await query(domainId, domain);
|
||||
|
||||
if (!cert) {
|
||||
logger.warn(`Certificate not found for domain: ${domainId}`);
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
|
||||
}
|
||||
|
||||
return response<GetCertificateResponse>(res, {
|
||||
data: cert,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Certificate retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
15
server/private/routers/certificates/index.ts
Normal file
15
server/private/routers/certificates/index.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./getCertificate";
|
||||
export * from "./restartCertificate";
|
||||
116
server/private/routers/certificates/restartCertificate.ts
Normal file
116
server/private/routers/certificates/restartCertificate.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { certificates, db } from "@server/db";
|
||||
import { sites } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import stoi from "@server/lib/stoi";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const restartCertificateParamsSchema = z
|
||||
.object({
|
||||
certId: z.string().transform(stoi).pipe(z.number().int().positive()),
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/certificate/{certId}",
|
||||
description: "Restart a certificate by ID.",
|
||||
tags: ["Certificate"],
|
||||
request: {
|
||||
params: z.object({
|
||||
certId: z
|
||||
.string()
|
||||
.transform(stoi)
|
||||
.pipe(z.number().int().positive()),
|
||||
orgId: z.string()
|
||||
})
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function restartCertificate(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = restartCertificateParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { certId } = parsedParams.data;
|
||||
|
||||
// get the certificate by ID
|
||||
const [cert] = await db
|
||||
.select()
|
||||
.from(certificates)
|
||||
.where(eq(certificates.certId, certId))
|
||||
.limit(1);
|
||||
|
||||
if (!cert) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
|
||||
);
|
||||
}
|
||||
|
||||
if (cert.status != "failed" && cert.status != "expired") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Certificate is already valid, no need to restart"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// update the certificate status to 'pending'
|
||||
await db
|
||||
.update(certificates)
|
||||
.set({
|
||||
status: "pending",
|
||||
errorMessage: null,
|
||||
lastRenewalAttempt: Math.floor(Date.now() / 1000)
|
||||
})
|
||||
.where(eq(certificates.certId, certId));
|
||||
|
||||
return response<null>(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Certificate restarted successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { db, domainNamespaces, resources } from "@server/db";
|
||||
import { inArray } from "drizzle-orm";
|
||||
|
||||
const paramsSchema = z.object({}).strict();
|
||||
|
||||
const querySchema = z
|
||||
.object({
|
||||
subdomain: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type CheckDomainAvailabilityResponse = {
|
||||
available: boolean;
|
||||
options: {
|
||||
domainNamespaceId: string;
|
||||
domainId: string;
|
||||
fullDomain: string;
|
||||
}[];
|
||||
};
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/domain/check-namespace-availability",
|
||||
description: "Check if a domain namespace is available based on subdomain",
|
||||
tags: [OpenAPITags.Domain],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
query: querySchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function checkDomainNamespaceAvailability(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedQuery = querySchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
const { subdomain } = parsedQuery.data;
|
||||
|
||||
const namespaces = await db.select().from(domainNamespaces);
|
||||
let possibleDomains = namespaces.map((ns) => {
|
||||
const desired = `${subdomain}.${ns.domainNamespaceId}`;
|
||||
return {
|
||||
fullDomain: desired,
|
||||
domainId: ns.domainId,
|
||||
domainNamespaceId: ns.domainNamespaceId
|
||||
};
|
||||
});
|
||||
|
||||
if (!possibleDomains.length) {
|
||||
return response<CheckDomainAvailabilityResponse>(res, {
|
||||
data: {
|
||||
available: false,
|
||||
options: []
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "No domain namespaces available",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
const existingResources = await db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(
|
||||
inArray(
|
||||
resources.fullDomain,
|
||||
possibleDomains.map((d) => d.fullDomain)
|
||||
)
|
||||
);
|
||||
|
||||
possibleDomains = possibleDomains.filter(
|
||||
(domain) =>
|
||||
!existingResources.some(
|
||||
(resource) => resource.fullDomain === domain.fullDomain
|
||||
)
|
||||
);
|
||||
|
||||
return response<CheckDomainAvailabilityResponse>(res, {
|
||||
data: {
|
||||
available: possibleDomains.length > 0,
|
||||
options: possibleDomains
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Domain namespaces checked successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
15
server/private/routers/domain/index.ts
Normal file
15
server/private/routers/domain/index.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./checkDomainNamespaceAvailability";
|
||||
export * from "./listDomainNamespaces";
|
||||
130
server/private/routers/domain/listDomainNamespaces.ts
Normal file
130
server/private/routers/domain/listDomainNamespaces.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, domainNamespaces } from "@server/db";
|
||||
import { domains } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import { eq, sql } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const paramsSchema = z.object({}).strict();
|
||||
|
||||
const querySchema = z
|
||||
.object({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().nonnegative())
|
||||
})
|
||||
.strict();
|
||||
|
||||
async function query(limit: number, offset: number) {
|
||||
const res = await db
|
||||
.select({
|
||||
domainNamespaceId: domainNamespaces.domainNamespaceId,
|
||||
domainId: domainNamespaces.domainId
|
||||
})
|
||||
.from(domainNamespaces)
|
||||
.innerJoin(
|
||||
domains,
|
||||
eq(domains.domainId, domainNamespaces.domainNamespaceId)
|
||||
)
|
||||
.limit(limit)
|
||||
.offset(offset);
|
||||
return res;
|
||||
}
|
||||
|
||||
export type ListDomainNamespacesResponse = {
|
||||
domainNamespaces: NonNullable<Awaited<ReturnType<typeof query>>>;
|
||||
pagination: { total: number; limit: number; offset: number };
|
||||
};
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/domains/namepaces",
|
||||
description: "List all domain namespaces in the system",
|
||||
tags: [OpenAPITags.Domain],
|
||||
request: {
|
||||
query: querySchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function listDomainNamespaces(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedQuery = querySchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
const { limit, offset } = parsedQuery.data;
|
||||
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const domainNamespacesList = await query(limit, offset);
|
||||
|
||||
const [{ count }] = await db
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(domainNamespaces);
|
||||
|
||||
return response<ListDomainNamespacesResponse>(res, {
|
||||
data: {
|
||||
domainNamespaces: domainNamespacesList,
|
||||
pagination: {
|
||||
total: count,
|
||||
limit,
|
||||
offset
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Namespaces retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
262
server/private/routers/external.ts
Normal file
262
server/private/routers/external.ts
Normal file
@@ -0,0 +1,262 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import * as certificates from "#private/routers/certificates";
|
||||
import { createStore } from "#private/lib/rateLimitStore";
|
||||
import * as billing from "#private/routers/billing";
|
||||
import * as remoteExitNode from "#private/routers/remoteExitNode";
|
||||
import * as loginPage from "#private/routers/loginPage";
|
||||
import * as orgIdp from "#private/routers/orgIdp";
|
||||
import * as domain from "#private/routers/domain";
|
||||
import * as auth from "#private/routers/auth";
|
||||
|
||||
import { Router } from "express";
|
||||
import { verifyOrgAccess, verifySessionUserMiddleware, verifyUserHasAction } from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import {
|
||||
verifyCertificateAccess,
|
||||
verifyIdpAccess,
|
||||
verifyLoginPageAccess,
|
||||
verifyRemoteExitNodeAccess
|
||||
} from "#private/middlewares";
|
||||
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
|
||||
import { unauthenticated as ua, authenticated as a } from "@server/routers/external";
|
||||
|
||||
export const authenticated = a;
|
||||
export const unauthenticated = ua;
|
||||
|
||||
unauthenticated.post(
|
||||
"/quick-start",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000,
|
||||
max: 100,
|
||||
keyGenerator: (req) => req.path,
|
||||
handler: (req, res, next) => {
|
||||
const message = `We're too busy right now. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.quickStart
|
||||
);
|
||||
|
||||
unauthenticated.post(
|
||||
"/remote-exit-node/quick-start",
|
||||
rateLimit({
|
||||
windowMs: 60 * 60 * 1000,
|
||||
max: 5,
|
||||
keyGenerator: (req) => `${req.path}:${ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only create 5 remote exit nodes every hour. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
remoteExitNode.quickStartRemoteExitNode
|
||||
);
|
||||
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/idp/oidc",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createIdp),
|
||||
orgIdp.createOrgOidcIdp
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/idp/:idpId/oidc",
|
||||
verifyOrgAccess,
|
||||
verifyIdpAccess,
|
||||
verifyUserHasAction(ActionsEnum.updateIdp),
|
||||
orgIdp.updateOrgOidcIdp
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/idp/:idpId",
|
||||
verifyOrgAccess,
|
||||
verifyIdpAccess,
|
||||
verifyUserHasAction(ActionsEnum.deleteIdp),
|
||||
orgIdp.deleteOrgIdp
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/idp/:idpId",
|
||||
verifyOrgAccess,
|
||||
verifyIdpAccess,
|
||||
verifyUserHasAction(ActionsEnum.getIdp),
|
||||
orgIdp.getOrgIdp
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/idp",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.listIdps),
|
||||
orgIdp.listOrgIdps
|
||||
);
|
||||
|
||||
authenticated.get("/org/:orgId/idp", orgIdp.listOrgIdps); // anyone can see this; it's just a list of idp names and ids
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/certificate/:domainId/:domain",
|
||||
verifyOrgAccess,
|
||||
verifyCertificateAccess,
|
||||
verifyUserHasAction(ActionsEnum.getCertificate),
|
||||
certificates.getCertificate
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/certificate/:certId/restart",
|
||||
verifyOrgAccess,
|
||||
verifyCertificateAccess,
|
||||
verifyUserHasAction(ActionsEnum.restartCertificate),
|
||||
certificates.restartCertificate
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/billing/create-checkout-session",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.billing),
|
||||
billing.createCheckoutSession
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/billing/create-portal-session",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.billing),
|
||||
billing.createPortalSession
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/billing/subscription",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.billing),
|
||||
billing.getOrgSubscription
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/billing/usage",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.billing),
|
||||
billing.getOrgUsage
|
||||
);
|
||||
|
||||
authenticated.get("/domain/namespaces", domain.listDomainNamespaces);
|
||||
|
||||
authenticated.get(
|
||||
"/domain/check-namespace-availability",
|
||||
domain.checkDomainNamespaceAvailability
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/remote-exit-node",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createRemoteExitNode),
|
||||
remoteExitNode.createRemoteExitNode
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/remote-exit-nodes",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.listRemoteExitNode),
|
||||
remoteExitNode.listRemoteExitNodes
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/remote-exit-node/:remoteExitNodeId",
|
||||
verifyOrgAccess,
|
||||
verifyRemoteExitNodeAccess,
|
||||
verifyUserHasAction(ActionsEnum.getRemoteExitNode),
|
||||
remoteExitNode.getRemoteExitNode
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/pick-remote-exit-node-defaults",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createRemoteExitNode),
|
||||
remoteExitNode.pickRemoteExitNodeDefaults
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/remote-exit-node/:remoteExitNodeId",
|
||||
verifyOrgAccess,
|
||||
verifyRemoteExitNodeAccess,
|
||||
verifyUserHasAction(ActionsEnum.deleteRemoteExitNode),
|
||||
remoteExitNode.deleteRemoteExitNode
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/login-page",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createLoginPage),
|
||||
loginPage.createLoginPage
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/login-page/:loginPageId",
|
||||
verifyOrgAccess,
|
||||
verifyLoginPageAccess,
|
||||
verifyUserHasAction(ActionsEnum.updateLoginPage),
|
||||
loginPage.updateLoginPage
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/login-page/:loginPageId",
|
||||
verifyOrgAccess,
|
||||
verifyLoginPageAccess,
|
||||
verifyUserHasAction(ActionsEnum.deleteLoginPage),
|
||||
loginPage.deleteLoginPage
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/login-page",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.getLoginPage),
|
||||
loginPage.getLoginPage
|
||||
);
|
||||
|
||||
export const authRouter = Router();
|
||||
|
||||
authRouter.post(
|
||||
"/remoteExitNode/get-token",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000,
|
||||
max: 900,
|
||||
keyGenerator: (req) =>
|
||||
`remoteExitNodeGetToken:${req.body.newtId || ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only request an remoteExitNodeToken token ${900} times every ${15} minutes. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
remoteExitNode.getRemoteExitNodeToken
|
||||
);
|
||||
|
||||
authRouter.post(
|
||||
"/transfer-session-token",
|
||||
rateLimit({
|
||||
windowMs: 1 * 60 * 1000,
|
||||
max: 60,
|
||||
keyGenerator: (req) =>
|
||||
`transferSessionToken:${ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only transfer a session token ${5} times every ${1} minute. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.transferSession
|
||||
);
|
||||
67
server/private/routers/gerbil/createExitNode.ts
Normal file
67
server/private/routers/gerbil/createExitNode.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { db, ExitNode, exitNodes } from "@server/db";
|
||||
import { getUniqueExitNodeEndpointName } from "@server/db/names";
|
||||
import config from "@server/lib/config";
|
||||
import { getNextAvailableSubnet } from "@server/lib/exitNodes";
|
||||
import logger from "@server/logger";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
export async function createExitNode(
|
||||
publicKey: string,
|
||||
reachableAt: string | undefined
|
||||
) {
|
||||
// Fetch exit node
|
||||
const [exitNodeQuery] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.publicKey, publicKey));
|
||||
let exitNode: ExitNode;
|
||||
if (!exitNodeQuery) {
|
||||
const address = await getNextAvailableSubnet();
|
||||
// TODO: eventually we will want to get the next available port so that we can multiple exit nodes
|
||||
// const listenPort = await getNextAvailablePort();
|
||||
const listenPort = config.getRawConfig().gerbil.start_port;
|
||||
let subEndpoint = "";
|
||||
if (config.getRawConfig().gerbil.use_subdomain) {
|
||||
subEndpoint = await getUniqueExitNodeEndpointName();
|
||||
}
|
||||
|
||||
const exitNodeName =
|
||||
config.getRawConfig().gerbil.exit_node_name ||
|
||||
`Exit Node ${publicKey.slice(0, 8)}`;
|
||||
|
||||
// create a new exit node
|
||||
[exitNode] = await db
|
||||
.insert(exitNodes)
|
||||
.values({
|
||||
publicKey,
|
||||
endpoint: `${subEndpoint}${subEndpoint != "" ? "." : ""}${config.getRawConfig().gerbil.base_endpoint}`,
|
||||
address,
|
||||
listenPort,
|
||||
reachableAt,
|
||||
name: exitNodeName
|
||||
})
|
||||
.returning()
|
||||
.execute();
|
||||
|
||||
logger.info(
|
||||
`Created new exit node ${exitNode.name} with address ${exitNode.address} and port ${exitNode.listenPort}`
|
||||
);
|
||||
} else {
|
||||
exitNode = exitNodeQuery;
|
||||
}
|
||||
|
||||
return exitNode;
|
||||
}
|
||||
13
server/private/routers/gerbil/receiveBandwidth.ts
Normal file
13
server/private/routers/gerbil/receiveBandwidth.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
1492
server/private/routers/hybrid.ts
Normal file
1492
server/private/routers/hybrid.ts
Normal file
File diff suppressed because it is too large
Load Diff
42
server/private/routers/integration.ts
Normal file
42
server/private/routers/integration.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import * as orgIdp from "#private/routers/orgIdp";
|
||||
import * as org from "#private/routers/org";
|
||||
|
||||
import { Router } from "express";
|
||||
import {
|
||||
verifyApiKey,
|
||||
verifyApiKeyHasAction,
|
||||
verifyApiKeyIsRoot,
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
|
||||
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
|
||||
|
||||
export const unauthenticated = ua;
|
||||
export const authenticated = a;
|
||||
|
||||
authenticated.post(
|
||||
`/org/:orgId/send-usage-notification`,
|
||||
verifyApiKeyIsRoot, // We are the only ones who can use root key so its fine
|
||||
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
|
||||
org.sendUsageNotification
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/idp/:idpId",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
|
||||
orgIdp.deleteOrgIdp
|
||||
);
|
||||
36
server/private/routers/internal.ts
Normal file
36
server/private/routers/internal.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import * as loginPage from "#private/routers/loginPage";
|
||||
import * as auth from "#private/routers/auth";
|
||||
import * as orgIdp from "#private/routers/orgIdp";
|
||||
import * as billing from "#private/routers/billing";
|
||||
|
||||
import { Router } from "express";
|
||||
import { verifySessionUserMiddleware } from "@server/middlewares";
|
||||
|
||||
import { internalRouter as ir } from "@server/routers/internal";
|
||||
|
||||
export const internalRouter = ir;
|
||||
|
||||
internalRouter.get("/org/:orgId/idp", orgIdp.listOrgIdps);
|
||||
|
||||
internalRouter.get("/org/:orgId/billing/tier", billing.getOrgTier);
|
||||
|
||||
internalRouter.get("/login-page", loginPage.loadLoginPage);
|
||||
|
||||
internalRouter.post(
|
||||
"/get-session-transfer-token",
|
||||
verifySessionUserMiddleware,
|
||||
auth.getSessionTransferToken
|
||||
);
|
||||
225
server/private/routers/loginPage/createLoginPage.ts
Normal file
225
server/private/routers/loginPage/createLoginPage.ts
Normal file
@@ -0,0 +1,225 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import {
|
||||
db,
|
||||
exitNodes,
|
||||
loginPage,
|
||||
LoginPage,
|
||||
loginPageOrg,
|
||||
resources,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { validateAndConstructDomain } from "@server/lib/domainUtils";
|
||||
import { createCertificate } from "#private/routers/certificates/createCertificate";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
subdomain: z.string().nullable().optional(),
|
||||
domainId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type CreateLoginPageBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type CreateLoginPageResponse = LoginPage;
|
||||
|
||||
export async function createLoginPage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { domainId, subdomain } = parsedBody.data;
|
||||
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(loginPageOrg)
|
||||
.where(eq(loginPageOrg.orgId, orgId));
|
||||
|
||||
if (existing) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A login page already exists for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const domainResult = await validateAndConstructDomain(
|
||||
domainId,
|
||||
orgId,
|
||||
subdomain
|
||||
);
|
||||
|
||||
if (!domainResult.success) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, domainResult.error)
|
||||
);
|
||||
}
|
||||
|
||||
const { fullDomain, subdomain: finalSubdomain } = domainResult;
|
||||
|
||||
logger.debug(`Full domain: ${fullDomain}`);
|
||||
|
||||
const existingResource = await db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(eq(resources.fullDomain, fullDomain));
|
||||
|
||||
if (existingResource.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"Resource with that domain already exists"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const existingLoginPages = await db
|
||||
.select()
|
||||
.from(loginPage)
|
||||
.where(eq(loginPage.fullDomain, fullDomain));
|
||||
|
||||
if (existingLoginPages.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"Login page with that domain already exists"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let returned: LoginPage | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
|
||||
const orgSites = await trx
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
|
||||
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.limit(10);
|
||||
|
||||
let exitNodesList = orgSites.map((s) => s.exitNodes);
|
||||
|
||||
if (exitNodesList.length === 0) {
|
||||
exitNodesList = await trx
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.limit(10);
|
||||
}
|
||||
|
||||
// select a random exit node
|
||||
const randomExitNode =
|
||||
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
|
||||
|
||||
if (!randomExitNode) {
|
||||
throw new Error("No exit nodes available");
|
||||
}
|
||||
|
||||
const [returnedLoginPage] = await db
|
||||
.insert(loginPage)
|
||||
.values({
|
||||
subdomain: finalSubdomain,
|
||||
fullDomain,
|
||||
domainId,
|
||||
exitNodeId: randomExitNode.exitNodeId
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx.insert(loginPageOrg).values({
|
||||
orgId,
|
||||
loginPageId: returnedLoginPage.loginPageId
|
||||
});
|
||||
|
||||
returned = returnedLoginPage;
|
||||
});
|
||||
|
||||
if (!returned) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create login page"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await createCertificate(domainId, fullDomain, db);
|
||||
|
||||
return response<LoginPage>(res, {
|
||||
data: returned,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Login page created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
106
server/private/routers/loginPage/deleteLoginPage.ts
Normal file
106
server/private/routers/loginPage/deleteLoginPage.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, loginPage, LoginPage, loginPageOrg } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string(),
|
||||
loginPageId: z.coerce.number()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type DeleteLoginPageResponse = LoginPage;
|
||||
|
||||
export async function deleteLoginPage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [existingLoginPage] = await db
|
||||
.select()
|
||||
.from(loginPage)
|
||||
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId))
|
||||
.innerJoin(
|
||||
loginPageOrg,
|
||||
eq(loginPageOrg.orgId, parsedParams.data.orgId)
|
||||
);
|
||||
|
||||
if (!existingLoginPage) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Login page not found")
|
||||
);
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(loginPageOrg)
|
||||
.where(
|
||||
and(
|
||||
eq(loginPageOrg.orgId, parsedParams.data.orgId),
|
||||
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
|
||||
)
|
||||
);
|
||||
|
||||
// const leftoverLinks = await db
|
||||
// .select()
|
||||
// .from(loginPageOrg)
|
||||
// .where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId))
|
||||
// .limit(1);
|
||||
|
||||
// if (!leftoverLinks.length) {
|
||||
await db
|
||||
.delete(loginPage)
|
||||
.where(
|
||||
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
|
||||
);
|
||||
|
||||
await db
|
||||
.delete(loginPageOrg)
|
||||
.where(
|
||||
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
|
||||
);
|
||||
// }
|
||||
|
||||
return response<LoginPage>(res, {
|
||||
data: existingLoginPage.loginPage,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Login page deleted successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
86
server/private/routers/loginPage/getLoginPage.ts
Normal file
86
server/private/routers/loginPage/getLoginPage.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, loginPage, loginPageOrg } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
async function query(orgId: string) {
|
||||
const [res] = await db
|
||||
.select()
|
||||
.from(loginPageOrg)
|
||||
.where(eq(loginPageOrg.orgId, orgId))
|
||||
.innerJoin(
|
||||
loginPage,
|
||||
eq(loginPage.loginPageId, loginPageOrg.loginPageId)
|
||||
)
|
||||
.limit(1);
|
||||
return res?.loginPage;
|
||||
}
|
||||
|
||||
export type GetLoginPageResponse = NonNullable<
|
||||
Awaited<ReturnType<typeof query>>
|
||||
>;
|
||||
|
||||
export async function getLoginPage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const loginPage = await query(orgId);
|
||||
|
||||
if (!loginPage) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Login page not found")
|
||||
);
|
||||
}
|
||||
|
||||
return response<GetLoginPageResponse>(res, {
|
||||
data: loginPage,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Login page retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
19
server/private/routers/loginPage/index.ts
Normal file
19
server/private/routers/loginPage/index.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./createLoginPage";
|
||||
export * from "./updateLoginPage";
|
||||
export * from "./getLoginPage";
|
||||
export * from "./loadLoginPage";
|
||||
export * from "./updateLoginPage";
|
||||
export * from "./deleteLoginPage";
|
||||
148
server/private/routers/loginPage/loadLoginPage.ts
Normal file
148
server/private/routers/loginPage/loadLoginPage.ts
Normal file
@@ -0,0 +1,148 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, idpOrg, loginPage, loginPageOrg, resources } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const querySchema = z.object({
|
||||
resourceId: z.coerce.number().int().positive().optional(),
|
||||
idpId: z.coerce.number().int().positive().optional(),
|
||||
orgId: z.coerce.number().int().positive().optional(),
|
||||
fullDomain: z.string().min(1)
|
||||
});
|
||||
|
||||
async function query(orgId: string | undefined, fullDomain: string) {
|
||||
if (!orgId) {
|
||||
const [res] = await db
|
||||
.select()
|
||||
.from(loginPage)
|
||||
.where(eq(loginPage.fullDomain, fullDomain))
|
||||
.innerJoin(
|
||||
loginPageOrg,
|
||||
eq(loginPage.loginPageId, loginPageOrg.loginPageId)
|
||||
)
|
||||
.limit(1);
|
||||
return {
|
||||
...res.loginPage,
|
||||
orgId: res.loginPageOrg.orgId
|
||||
};
|
||||
}
|
||||
|
||||
const [orgLink] = await db
|
||||
.select()
|
||||
.from(loginPageOrg)
|
||||
.where(eq(loginPageOrg.orgId, orgId));
|
||||
|
||||
if (!orgLink) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const [res] = await db
|
||||
.select()
|
||||
.from(loginPage)
|
||||
.where(
|
||||
and(
|
||||
eq(loginPage.loginPageId, orgLink.loginPageId),
|
||||
eq(loginPage.fullDomain, fullDomain)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
return {
|
||||
...res,
|
||||
orgId: orgLink.orgId
|
||||
};
|
||||
}
|
||||
|
||||
export type LoadLoginPageResponse = NonNullable<
|
||||
Awaited<ReturnType<typeof query>>
|
||||
> & { orgId: string };
|
||||
|
||||
export async function loadLoginPage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedQuery = querySchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { resourceId, idpId, fullDomain } = parsedQuery.data;
|
||||
|
||||
let orgId;
|
||||
if (resourceId) {
|
||||
const [resource] = await db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(eq(resources.resourceId, resourceId))
|
||||
.limit(1);
|
||||
|
||||
if (!resource) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Resource not found")
|
||||
);
|
||||
}
|
||||
|
||||
orgId = resource.orgId;
|
||||
} else if (idpId) {
|
||||
const [idpOrgLink] = await db
|
||||
.select()
|
||||
.from(idpOrg)
|
||||
.where(eq(idpOrg.idpId, idpId));
|
||||
|
||||
if (!idpOrgLink) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "IdP not found")
|
||||
);
|
||||
}
|
||||
|
||||
orgId = idpOrgLink.orgId;
|
||||
} else if (parsedQuery.data.orgId) {
|
||||
orgId = parsedQuery.data.orgId.toString();
|
||||
}
|
||||
|
||||
const loginPage = await query(orgId, fullDomain);
|
||||
|
||||
if (!loginPage) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Login page not found")
|
||||
);
|
||||
}
|
||||
|
||||
return response<LoadLoginPageResponse>(res, {
|
||||
data: loginPage,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Login page retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
227
server/private/routers/loginPage/updateLoginPage.ts
Normal file
227
server/private/routers/loginPage/updateLoginPage.ts
Normal file
@@ -0,0 +1,227 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, loginPage, LoginPage, loginPageOrg, resources } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { validateAndConstructDomain } from "@server/lib/domainUtils";
|
||||
import { subdomainSchema } from "@server/lib/schemas";
|
||||
import { createCertificate } from "#private/routers/certificates/createCertificate";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string(),
|
||||
loginPageId: z.coerce.number()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
subdomain: subdomainSchema.nullable().optional(),
|
||||
domainId: z.string().optional()
|
||||
})
|
||||
.strict()
|
||||
.refine((data) => Object.keys(data).length > 0, {
|
||||
message: "At least one field must be provided for update"
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.subdomain) {
|
||||
return subdomainSchema.safeParse(data.subdomain).success;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: "Invalid subdomain" }
|
||||
);
|
||||
|
||||
export type UpdateLoginPageBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type UpdateLoginPageResponse = LoginPage;
|
||||
|
||||
export async function updateLoginPage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const updateData = parsedBody.data;
|
||||
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { loginPageId, orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas"){
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const [existingLoginPage] = await db
|
||||
.select()
|
||||
.from(loginPage)
|
||||
.where(eq(loginPage.loginPageId, loginPageId));
|
||||
|
||||
if (!existingLoginPage) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Login page not found")
|
||||
);
|
||||
}
|
||||
|
||||
const [orgLink] = await db
|
||||
.select()
|
||||
.from(loginPageOrg)
|
||||
.where(
|
||||
and(
|
||||
eq(loginPageOrg.orgId, orgId),
|
||||
eq(loginPageOrg.loginPageId, loginPageId)
|
||||
)
|
||||
);
|
||||
|
||||
if (!orgLink) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"Login page not found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (updateData.domainId) {
|
||||
const domainId = updateData.domainId;
|
||||
|
||||
// Validate domain and construct full domain
|
||||
const domainResult = await validateAndConstructDomain(
|
||||
domainId,
|
||||
orgId,
|
||||
updateData.subdomain
|
||||
);
|
||||
|
||||
if (!domainResult.success) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, domainResult.error)
|
||||
);
|
||||
}
|
||||
|
||||
const { fullDomain, subdomain: finalSubdomain } = domainResult;
|
||||
|
||||
logger.debug(`Full domain: ${fullDomain}`);
|
||||
|
||||
if (fullDomain) {
|
||||
const [existingDomain] = await db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(eq(resources.fullDomain, fullDomain));
|
||||
|
||||
if (existingDomain) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"Resource with that domain already exists"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [existingLoginPage] = await db
|
||||
.select()
|
||||
.from(loginPage)
|
||||
.where(eq(loginPage.fullDomain, fullDomain));
|
||||
|
||||
if (
|
||||
existingLoginPage &&
|
||||
existingLoginPage.loginPageId !== loginPageId
|
||||
) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"Login page with that domain already exists"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// update the full domain if it has changed
|
||||
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
|
||||
await db
|
||||
.update(loginPage)
|
||||
.set({ fullDomain })
|
||||
.where(eq(loginPage.loginPageId, loginPageId));
|
||||
}
|
||||
|
||||
await createCertificate(domainId, fullDomain, db);
|
||||
}
|
||||
|
||||
updateData.subdomain = finalSubdomain;
|
||||
}
|
||||
|
||||
const updatedLoginPage = await db
|
||||
.update(loginPage)
|
||||
.set({ ...updateData })
|
||||
.where(eq(loginPage.loginPageId, loginPageId))
|
||||
.returning();
|
||||
|
||||
if (updatedLoginPage.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Login page with ID ${loginPageId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return response<LoginPage>(res, {
|
||||
data: updatedLoginPage[0],
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Login page created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
14
server/private/routers/org/index.ts
Normal file
14
server/private/routers/org/index.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./sendUsageNotifications";
|
||||
249
server/private/routers/org/sendUsageNotifications.ts
Normal file
249
server/private/routers/org/sendUsageNotifications.ts
Normal file
@@ -0,0 +1,249 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { userOrgs, users, roles, orgs } from "@server/db";
|
||||
import { eq, and, or } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { sendEmail } from "@server/emails";
|
||||
import NotifyUsageLimitApproaching from "@server/emails/templates/NotifyUsageLimitApproaching";
|
||||
import NotifyUsageLimitReached from "@server/emails/templates/NotifyUsageLimitReached";
|
||||
import config from "@server/lib/config";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const sendUsageNotificationParamsSchema = z.object({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const sendUsageNotificationBodySchema = z.object({
|
||||
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
|
||||
limitName: z.string(),
|
||||
currentUsage: z.number(),
|
||||
usageLimit: z.number(),
|
||||
});
|
||||
|
||||
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
|
||||
|
||||
export type SendUsageNotificationResponse = {
|
||||
success: boolean;
|
||||
emailsSent: number;
|
||||
adminEmails: string[];
|
||||
};
|
||||
|
||||
// WE SHOULD NOT REGISTER THE PATH IN SAAS
|
||||
// registry.registerPath({
|
||||
// method: "post",
|
||||
// path: "/org/{orgId}/send-usage-notification",
|
||||
// description: "Send usage limit notification emails to all organization admins.",
|
||||
// tags: [OpenAPITags.Org],
|
||||
// request: {
|
||||
// params: sendUsageNotificationParamsSchema,
|
||||
// body: {
|
||||
// content: {
|
||||
// "application/json": {
|
||||
// schema: sendUsageNotificationBodySchema
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// responses: {
|
||||
// 200: {
|
||||
// description: "Usage notifications sent successfully",
|
||||
// content: {
|
||||
// "application/json": {
|
||||
// schema: z.object({
|
||||
// success: z.boolean(),
|
||||
// emailsSent: z.number(),
|
||||
// adminEmails: z.array(z.string())
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
async function getOrgAdmins(orgId: string) {
|
||||
// Get all users in the organization who are either:
|
||||
// 1. Organization owners (isOwner = true)
|
||||
// 2. Have admin roles (role.isAdmin = true)
|
||||
const admins = await db
|
||||
.select({
|
||||
userId: users.userId,
|
||||
email: users.email,
|
||||
name: users.name,
|
||||
isOwner: userOrgs.isOwner,
|
||||
roleName: roles.name,
|
||||
isAdminRole: roles.isAdmin
|
||||
})
|
||||
.from(userOrgs)
|
||||
.innerJoin(users, eq(userOrgs.userId, users.userId))
|
||||
.leftJoin(roles, eq(userOrgs.roleId, roles.roleId))
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
or(
|
||||
eq(userOrgs.isOwner, true),
|
||||
eq(roles.isAdmin, true)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Filter to only include users with verified emails
|
||||
const orgAdmins = admins.filter(admin =>
|
||||
admin.email &&
|
||||
admin.email.length > 0
|
||||
);
|
||||
|
||||
return orgAdmins;
|
||||
}
|
||||
|
||||
export async function sendUsageNotification(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsedBody = sendUsageNotificationBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
const {
|
||||
notificationType,
|
||||
limitName,
|
||||
currentUsage,
|
||||
usageLimit,
|
||||
} = parsedBody.data;
|
||||
|
||||
// Verify organization exists
|
||||
const org = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (org.length === 0) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
|
||||
);
|
||||
}
|
||||
|
||||
// Get all admin users for this organization
|
||||
const orgAdmins = await getOrgAdmins(orgId);
|
||||
|
||||
if (orgAdmins.length === 0) {
|
||||
logger.warn(`No admin users found for organization ${orgId}`);
|
||||
return response<SendUsageNotificationResponse>(res, {
|
||||
data: {
|
||||
success: true,
|
||||
emailsSent: 0,
|
||||
adminEmails: []
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "No admin users found to notify",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
// Default billing link if not provided
|
||||
const defaultBillingLink = `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing`;
|
||||
|
||||
let emailsSent = 0;
|
||||
const adminEmails: string[] = [];
|
||||
|
||||
// Send emails to all admin users
|
||||
for (const admin of orgAdmins) {
|
||||
if (!admin.email) continue;
|
||||
|
||||
try {
|
||||
let template;
|
||||
let subject;
|
||||
|
||||
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
|
||||
template = NotifyUsageLimitApproaching({
|
||||
email: admin.email,
|
||||
limitName,
|
||||
currentUsage,
|
||||
usageLimit,
|
||||
billingLink: defaultBillingLink
|
||||
});
|
||||
subject = `Usage limit warning for ${limitName}`;
|
||||
} else {
|
||||
template = NotifyUsageLimitReached({
|
||||
email: admin.email,
|
||||
limitName,
|
||||
currentUsage,
|
||||
usageLimit,
|
||||
billingLink: defaultBillingLink
|
||||
});
|
||||
subject = `URGENT: Usage limit reached for ${limitName}`;
|
||||
}
|
||||
|
||||
await sendEmail(template, {
|
||||
to: admin.email,
|
||||
from: config.getNoReplyEmail(),
|
||||
subject
|
||||
});
|
||||
|
||||
emailsSent++;
|
||||
adminEmails.push(admin.email);
|
||||
|
||||
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
|
||||
} catch (emailError) {
|
||||
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
|
||||
// Continue with other admins even if one fails
|
||||
}
|
||||
}
|
||||
|
||||
return response<SendUsageNotificationResponse>(res, {
|
||||
data: {
|
||||
success: true,
|
||||
emailsSent,
|
||||
adminEmails
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: `Usage notifications sent to ${emailsSent} administrators`,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Error sending usage notifications:", error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
|
||||
);
|
||||
}
|
||||
}
|
||||
185
server/private/routers/orgIdp/createOrgOidcIdp.ts
Normal file
185
server/private/routers/orgIdp/createOrgOidcIdp.ts
Normal file
@@ -0,0 +1,185 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db";
|
||||
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
|
||||
const paramsSchema = z.object({ orgId: z.string().nonempty() }).strict();
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
name: z.string().nonempty(),
|
||||
clientId: z.string().nonempty(),
|
||||
clientSecret: z.string().nonempty(),
|
||||
authUrl: z.string().url(),
|
||||
tokenUrl: z.string().url(),
|
||||
identifierPath: z.string().nonempty(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().nonempty(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
|
||||
roleMapping: z.string().optional()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type CreateOrgIdpResponse = {
|
||||
idpId: number;
|
||||
redirectUrl: string;
|
||||
};
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "put",
|
||||
// path: "/idp/oidc",
|
||||
// description: "Create an OIDC IdP.",
|
||||
// tags: [OpenAPITags.Idp],
|
||||
// request: {
|
||||
// body: {
|
||||
// content: {
|
||||
// "application/json": {
|
||||
// schema: bodySchema
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function createOrgOidcIdp(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const {
|
||||
clientId,
|
||||
clientSecret,
|
||||
authUrl,
|
||||
tokenUrl,
|
||||
scopes,
|
||||
identifierPath,
|
||||
emailPath,
|
||||
namePath,
|
||||
name,
|
||||
autoProvision,
|
||||
variant,
|
||||
roleMapping
|
||||
} = parsedBody.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const key = config.getRawConfig().server.secret!;
|
||||
|
||||
const encryptedSecret = encrypt(clientSecret, key);
|
||||
const encryptedClientId = encrypt(clientId, key);
|
||||
|
||||
let idpId: number | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
const [idpRes] = await trx
|
||||
.insert(idp)
|
||||
.values({
|
||||
name,
|
||||
autoProvision,
|
||||
type: "oidc"
|
||||
})
|
||||
.returning();
|
||||
|
||||
idpId = idpRes.idpId;
|
||||
|
||||
await trx.insert(idpOidcConfig).values({
|
||||
idpId: idpRes.idpId,
|
||||
clientId: encryptedClientId,
|
||||
clientSecret: encryptedSecret,
|
||||
authUrl,
|
||||
tokenUrl,
|
||||
scopes,
|
||||
identifierPath,
|
||||
emailPath,
|
||||
namePath,
|
||||
variant
|
||||
});
|
||||
|
||||
await trx.insert(idpOrg).values({
|
||||
idpId: idpRes.idpId,
|
||||
orgId: orgId,
|
||||
roleMapping: roleMapping || null,
|
||||
orgMapping: `'${orgId}'`
|
||||
});
|
||||
});
|
||||
|
||||
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
|
||||
|
||||
return response<CreateOrgIdpResponse>(res, {
|
||||
data: {
|
||||
idpId: idpId as number,
|
||||
redirectUrl
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Org Idp created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
108
server/private/routers/orgIdp/deleteOrgIdp.ts
Normal file
108
server/private/routers/orgIdp/deleteOrgIdp.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { idp, idpOidcConfig, idpOrg } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string().optional(), // Optional; used with org idp in saas
|
||||
idpId: z.coerce.number()
|
||||
})
|
||||
.strict();
|
||||
|
||||
registry.registerPath({
|
||||
method: "delete",
|
||||
path: "/idp/{idpId}",
|
||||
description: "Delete IDP.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function deleteOrgIdp(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { idpId } = parsedParams.data;
|
||||
|
||||
// Check if IDP exists
|
||||
const [existingIdp] = await db
|
||||
.select()
|
||||
.from(idp)
|
||||
.where(eq(idp.idpId, idpId));
|
||||
|
||||
if (!existingIdp) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"IdP not found"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Delete the IDP and its related records in a transaction
|
||||
await db.transaction(async (trx) => {
|
||||
// Delete OIDC config if it exists
|
||||
await trx
|
||||
.delete(idpOidcConfig)
|
||||
.where(eq(idpOidcConfig.idpId, idpId));
|
||||
|
||||
// Delete IDP-org mappings
|
||||
await trx
|
||||
.delete(idpOrg)
|
||||
.where(eq(idpOrg.idpId, idpId));
|
||||
|
||||
// Delete the IDP itself
|
||||
await trx
|
||||
.delete(idp)
|
||||
.where(eq(idp.idpId, idpId));
|
||||
});
|
||||
|
||||
return response<null>(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "IdP deleted successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
117
server/private/routers/orgIdp/getOrgIdp.ts
Normal file
117
server/private/routers/orgIdp/getOrgIdp.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, idpOrg, loginPage, loginPageOrg } from "@server/db";
|
||||
import { idp, idpOidcConfig } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import config from "@server/lib/config";
|
||||
import { decrypt } from "@server/lib/crypto";
|
||||
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string().nonempty(),
|
||||
idpId: z.coerce.number()
|
||||
})
|
||||
.strict();
|
||||
|
||||
async function query(idpId: number, orgId: string) {
|
||||
const [res] = await db
|
||||
.select()
|
||||
.from(idp)
|
||||
.where(eq(idp.idpId, idpId))
|
||||
.leftJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idp.idpId))
|
||||
.leftJoin(
|
||||
idpOrg,
|
||||
and(eq(idpOrg.idpId, idp.idpId), eq(idpOrg.orgId, orgId))
|
||||
)
|
||||
.limit(1);
|
||||
return res;
|
||||
}
|
||||
|
||||
export type GetOrgIdpResponse = NonNullable<
|
||||
Awaited<ReturnType<typeof query>>
|
||||
> & { redirectUrl: string };
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/idp/{idpId}",
|
||||
// description: "Get an IDP by its IDP ID.",
|
||||
// tags: [OpenAPITags.Idp],
|
||||
// request: {
|
||||
// params: paramsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function getOrgIdp(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
|
||||
const idpRes = await query(idpId, orgId);
|
||||
|
||||
if (!idpRes) {
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "Idp not found"));
|
||||
}
|
||||
|
||||
const key = config.getRawConfig().server.secret!;
|
||||
|
||||
if (idpRes.idp.type === "oidc") {
|
||||
const clientSecret = idpRes.idpOidcConfig!.clientSecret;
|
||||
const clientId = idpRes.idpOidcConfig!.clientId;
|
||||
|
||||
idpRes.idpOidcConfig!.clientSecret = decrypt(clientSecret, key);
|
||||
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
|
||||
}
|
||||
|
||||
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
|
||||
|
||||
return response<GetOrgIdpResponse>(res, {
|
||||
data: {
|
||||
...idpRes,
|
||||
redirectUrl
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Org Idp retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
18
server/private/routers/orgIdp/index.ts
Normal file
18
server/private/routers/orgIdp/index.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./createOrgOidcIdp";
|
||||
export * from "./getOrgIdp";
|
||||
export * from "./listOrgIdps";
|
||||
export * from "./updateOrgOidcIdp";
|
||||
export * from "./deleteOrgIdp";
|
||||
142
server/private/routers/orgIdp/listOrgIdps.ts
Normal file
142
server/private/routers/orgIdp/listOrgIdps.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, idpOidcConfig } from "@server/db";
|
||||
import { idp, idpOrg } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import { eq, sql } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const querySchema = z
|
||||
.object({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().nonnegative())
|
||||
})
|
||||
.strict();
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string().nonempty()
|
||||
})
|
||||
.strict();
|
||||
|
||||
async function query(orgId: string, limit: number, offset: number) {
|
||||
const res = await db
|
||||
.select({
|
||||
idpId: idp.idpId,
|
||||
orgId: idpOrg.orgId,
|
||||
name: idp.name,
|
||||
type: idp.type,
|
||||
variant: idpOidcConfig.variant
|
||||
})
|
||||
.from(idpOrg)
|
||||
.where(eq(idpOrg.orgId, orgId))
|
||||
.innerJoin(idp, eq(idp.idpId, idpOrg.idpId))
|
||||
.innerJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idpOrg.idpId))
|
||||
.orderBy(sql`idp.name DESC`)
|
||||
.limit(limit)
|
||||
.offset(offset);
|
||||
return res;
|
||||
}
|
||||
|
||||
export type ListOrgIdpsResponse = {
|
||||
idps: Awaited<ReturnType<typeof query>>;
|
||||
pagination: {
|
||||
total: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
};
|
||||
};
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/idp",
|
||||
// description: "List all IDP in the system.",
|
||||
// tags: [OpenAPITags.Idp],
|
||||
// request: {
|
||||
// query: querySchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function listOrgIdps(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const parsedQuery = querySchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
const { limit, offset } = parsedQuery.data;
|
||||
|
||||
const list = await query(orgId, limit, offset);
|
||||
|
||||
const [{ count }] = await db
|
||||
.select({ count: sql<number>`count(*)` })
|
||||
.from(idp);
|
||||
|
||||
return response<ListOrgIdpsResponse>(res, {
|
||||
data: {
|
||||
idps: list,
|
||||
pagination: {
|
||||
total: count,
|
||||
limit,
|
||||
offset
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Org Idps retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
236
server/private/routers/orgIdp/updateOrgOidcIdp.ts
Normal file
236
server/private/routers/orgIdp/updateOrgOidcIdp.ts
Normal file
@@ -0,0 +1,236 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, idpOrg } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { idp, idpOidcConfig } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string().nonempty(),
|
||||
idpId: z.coerce.number()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
name: z.string().optional(),
|
||||
clientId: z.string().optional(),
|
||||
clientSecret: z.string().optional(),
|
||||
authUrl: z.string().optional(),
|
||||
tokenUrl: z.string().optional(),
|
||||
identifierPath: z.string().optional(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().optional(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
roleMapping: z.string().optional()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type UpdateOrgIdpResponse = {
|
||||
idpId: number;
|
||||
};
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "post",
|
||||
// path: "/idp/{idpId}/oidc",
|
||||
// description: "Update an OIDC IdP.",
|
||||
// tags: [OpenAPITags.Idp],
|
||||
// request: {
|
||||
// params: paramsSchema,
|
||||
// body: {
|
||||
// content: {
|
||||
// "application/json": {
|
||||
// schema: bodySchema
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function updateOrgOidcIdp(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
const {
|
||||
clientId,
|
||||
clientSecret,
|
||||
authUrl,
|
||||
tokenUrl,
|
||||
scopes,
|
||||
identifierPath,
|
||||
emailPath,
|
||||
namePath,
|
||||
name,
|
||||
autoProvision,
|
||||
roleMapping
|
||||
} = parsedBody.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if IDP exists and is of type OIDC
|
||||
const [existingIdp] = await db
|
||||
.select()
|
||||
.from(idp)
|
||||
.where(eq(idp.idpId, idpId));
|
||||
|
||||
if (!existingIdp) {
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
|
||||
}
|
||||
|
||||
const [existingIdpOrg] = await db
|
||||
.select()
|
||||
.from(idpOrg)
|
||||
.where(and(eq(idpOrg.orgId, orgId), eq(idpOrg.idpId, idpId)));
|
||||
|
||||
if (!existingIdpOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"IdP not found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existingIdp.type !== "oidc") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"IdP is not an OIDC provider"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const key = config.getRawConfig().server.secret!;
|
||||
const encryptedSecret = clientSecret
|
||||
? encrypt(clientSecret, key)
|
||||
: undefined;
|
||||
const encryptedClientId = clientId ? encrypt(clientId, key) : undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const idpData = {
|
||||
name,
|
||||
autoProvision
|
||||
};
|
||||
|
||||
// only update if at least one key is not undefined
|
||||
let keysToUpdate = Object.keys(idpData).filter(
|
||||
(key) => idpData[key as keyof typeof idpData] !== undefined
|
||||
);
|
||||
|
||||
if (keysToUpdate.length > 0) {
|
||||
await trx.update(idp).set(idpData).where(eq(idp.idpId, idpId));
|
||||
}
|
||||
|
||||
const configData = {
|
||||
clientId: encryptedClientId,
|
||||
clientSecret: encryptedSecret,
|
||||
authUrl,
|
||||
tokenUrl,
|
||||
scopes,
|
||||
identifierPath,
|
||||
emailPath,
|
||||
namePath
|
||||
};
|
||||
|
||||
keysToUpdate = Object.keys(configData).filter(
|
||||
(key) =>
|
||||
configData[key as keyof typeof configData] !== undefined
|
||||
);
|
||||
|
||||
if (keysToUpdate.length > 0) {
|
||||
// Update OIDC config
|
||||
await trx
|
||||
.update(idpOidcConfig)
|
||||
.set(configData)
|
||||
.where(eq(idpOidcConfig.idpId, idpId));
|
||||
}
|
||||
|
||||
if (roleMapping !== undefined) {
|
||||
// Update IdP-org policy
|
||||
await trx
|
||||
.update(idpOrg)
|
||||
.set({
|
||||
roleMapping
|
||||
})
|
||||
.where(
|
||||
and(eq(idpOrg.idpId, idpId), eq(idpOrg.orgId, orgId))
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
return response<UpdateOrgIdpResponse>(res, {
|
||||
data: {
|
||||
idpId
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Org IdP updated successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
278
server/private/routers/remoteExitNode/createRemoteExitNode.ts
Normal file
278
server/private/routers/remoteExitNode/createRemoteExitNode.ts
Normal file
@@ -0,0 +1,278 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import moment from "moment";
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNode";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { hashPassword, verifyPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { getNextAvailableSubnet } from "@server/lib/exitNodes";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
|
||||
export const paramsSchema = z.object({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export type CreateRemoteExitNodeResponse = {
|
||||
token: string;
|
||||
remoteExitNodeId: string;
|
||||
secret: string;
|
||||
};
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type CreateRemoteExitNodeBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export async function createRemoteExitNode(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { remoteExitNodeId, secret } = parsedBody.data;
|
||||
|
||||
if (req.user && !req.userOrgRoleId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
}
|
||||
|
||||
const usage = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES
|
||||
);
|
||||
if (!usage) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"No usage data found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
const rejectRemoteExitNodes = await usageService.checkLimitSet(
|
||||
orgId,
|
||||
false,
|
||||
FeatureId.REMOTE_EXIT_NODES,
|
||||
{
|
||||
...usage,
|
||||
instantaneousValue: (usage.instantaneousValue || 0) + 1
|
||||
} // We need to add one to know if we are violating the limit
|
||||
);
|
||||
if (rejectRemoteExitNodes) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Remote exit node limit exceeded. Please upgrade your plan or contact us at support@fossorial.io"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
// const address = await getNextAvailableSubnet();
|
||||
const address = "100.89.140.1/24"; // FOR NOW LETS HARDCODE THESE ADDRESSES
|
||||
|
||||
const [existingRemoteExitNode] = await db
|
||||
.select()
|
||||
.from(remoteExitNodes)
|
||||
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
|
||||
|
||||
if (existingRemoteExitNode) {
|
||||
// validate the secret
|
||||
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingRemoteExitNode.secretHash
|
||||
);
|
||||
if (!validSecret) {
|
||||
logger.info(
|
||||
`Failed secret validation for remote exit node: ${remoteExitNodeId}`
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Invalid secret for remote exit node"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let existingExitNode: ExitNode | null = null;
|
||||
if (existingRemoteExitNode?.exitNodeId) {
|
||||
const [res] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(
|
||||
eq(exitNodes.exitNodeId, existingRemoteExitNode.exitNodeId)
|
||||
);
|
||||
existingExitNode = res;
|
||||
}
|
||||
|
||||
let existingExitNodeOrg: ExitNodeOrg | null = null;
|
||||
if (existingRemoteExitNode?.exitNodeId) {
|
||||
const [res] = await db
|
||||
.select()
|
||||
.from(exitNodeOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(
|
||||
exitNodeOrgs.exitNodeId,
|
||||
existingRemoteExitNode.exitNodeId
|
||||
),
|
||||
eq(exitNodeOrgs.orgId, orgId)
|
||||
)
|
||||
);
|
||||
existingExitNodeOrg = res;
|
||||
}
|
||||
|
||||
if (existingExitNodeOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Remote exit node already exists in this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
if (!existingExitNode) {
|
||||
const [res] = await trx
|
||||
.insert(exitNodes)
|
||||
.values({
|
||||
name: remoteExitNodeId,
|
||||
address,
|
||||
endpoint: "",
|
||||
publicKey: "",
|
||||
listenPort: 0,
|
||||
online: false,
|
||||
type: "remoteExitNode"
|
||||
})
|
||||
.returning();
|
||||
existingExitNode = res;
|
||||
}
|
||||
|
||||
if (!existingRemoteExitNode) {
|
||||
await trx.insert(remoteExitNodes).values({
|
||||
remoteExitNodeId: remoteExitNodeId,
|
||||
secretHash,
|
||||
dateCreated: moment().toISOString(),
|
||||
exitNodeId: existingExitNode.exitNodeId
|
||||
});
|
||||
} else {
|
||||
// update the existing remote exit node
|
||||
await trx
|
||||
.update(remoteExitNodes)
|
||||
.set({
|
||||
exitNodeId: existingExitNode.exitNodeId
|
||||
})
|
||||
.where(
|
||||
eq(
|
||||
remoteExitNodes.remoteExitNodeId,
|
||||
existingRemoteExitNode.remoteExitNodeId
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!existingExitNodeOrg) {
|
||||
await trx.insert(exitNodeOrgs).values({
|
||||
exitNodeId: existingExitNode.exitNodeId,
|
||||
orgId: orgId
|
||||
});
|
||||
}
|
||||
|
||||
numExitNodeOrgs = await trx
|
||||
.select()
|
||||
.from(exitNodeOrgs)
|
||||
.where(eq(exitNodeOrgs.orgId, orgId));
|
||||
});
|
||||
|
||||
if (numExitNodeOrgs) {
|
||||
await usageService.updateDaily(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES,
|
||||
numExitNodeOrgs.length
|
||||
);
|
||||
}
|
||||
|
||||
const token = generateSessionToken();
|
||||
await createRemoteExitNodeSession(token, remoteExitNodeId);
|
||||
|
||||
return response<CreateRemoteExitNodeResponse>(res, {
|
||||
data: {
|
||||
remoteExitNodeId,
|
||||
secret,
|
||||
token
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "RemoteExitNode created successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A remote exit node with that ID already exists"
|
||||
)
|
||||
);
|
||||
} else {
|
||||
logger.error("Failed to create remoteExitNode", e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create remoteExitNode"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
131
server/private/routers/remoteExitNode/deleteRemoteExitNode.ts
Normal file
131
server/private/routers/remoteExitNode/deleteRemoteExitNode.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
import { and, count, eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function deleteRemoteExitNode(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId, remoteExitNodeId } = parsedParams.data;
|
||||
|
||||
const [remoteExitNode] = await db
|
||||
.select()
|
||||
.from(remoteExitNodes)
|
||||
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
|
||||
|
||||
if (!remoteExitNode) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Remote exit node with ID ${remoteExitNodeId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!remoteExitNode.exitNodeId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
`Remote exit node with ID ${remoteExitNodeId} does not have an exit node ID`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
await trx
|
||||
.delete(exitNodeOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(exitNodeOrgs.orgId, orgId),
|
||||
eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!)
|
||||
)
|
||||
);
|
||||
|
||||
const [remainingExitNodeOrgs] = await trx
|
||||
.select({ count: count() })
|
||||
.from(exitNodeOrgs)
|
||||
.where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!));
|
||||
|
||||
if (remainingExitNodeOrgs.count === 0) {
|
||||
await trx
|
||||
.delete(remoteExitNodes)
|
||||
.where(
|
||||
eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)
|
||||
);
|
||||
await trx
|
||||
.delete(exitNodes)
|
||||
.where(
|
||||
eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!)
|
||||
);
|
||||
}
|
||||
|
||||
numExitNodeOrgs = await trx
|
||||
.select()
|
||||
.from(exitNodeOrgs)
|
||||
.where(eq(exitNodeOrgs.orgId, orgId));
|
||||
});
|
||||
|
||||
if (numExitNodeOrgs) {
|
||||
await usageService.updateDaily(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES,
|
||||
numExitNodeOrgs.length
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Remote exit node deleted successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
99
server/private/routers/remoteExitNode/getRemoteExitNode.ts
Normal file
99
server/private/routers/remoteExitNode/getRemoteExitNode.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, exitNodes } from "@server/db";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const getRemoteExitNodeSchema = z
|
||||
.object({
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
})
|
||||
.strict();
|
||||
|
||||
async function query(remoteExitNodeId: string) {
|
||||
const [remoteExitNode] = await db
|
||||
.select({
|
||||
remoteExitNodeId: remoteExitNodes.remoteExitNodeId,
|
||||
dateCreated: remoteExitNodes.dateCreated,
|
||||
version: remoteExitNodes.version,
|
||||
exitNodeId: remoteExitNodes.exitNodeId,
|
||||
name: exitNodes.name,
|
||||
address: exitNodes.address,
|
||||
endpoint: exitNodes.endpoint,
|
||||
online: exitNodes.online,
|
||||
type: exitNodes.type
|
||||
})
|
||||
.from(remoteExitNodes)
|
||||
.innerJoin(
|
||||
exitNodes,
|
||||
eq(exitNodes.exitNodeId, remoteExitNodes.exitNodeId)
|
||||
)
|
||||
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId))
|
||||
.limit(1);
|
||||
return remoteExitNode;
|
||||
}
|
||||
|
||||
export type GetRemoteExitNodeResponse = Awaited<ReturnType<typeof query>>;
|
||||
|
||||
export async function getRemoteExitNode(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = getRemoteExitNodeSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { remoteExitNodeId } = parsedParams.data;
|
||||
|
||||
const remoteExitNode = await query(remoteExitNodeId);
|
||||
|
||||
if (!remoteExitNode) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Remote exit node with ID ${remoteExitNodeId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return response<GetRemoteExitNodeResponse>(res, {
|
||||
data: remoteExitNode,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Remote exit node retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
130
server/private/routers/remoteExitNode/getRemoteExitNodeToken.ts
Normal file
130
server/private/routers/remoteExitNode/getRemoteExitNodeToken.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import { db } from "@server/db";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import response from "@server/lib/response";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import {
|
||||
createRemoteExitNodeSession,
|
||||
validateRemoteExitNodeSessionToken
|
||||
} from "#private/auth/sessions/remoteExitNode";
|
||||
import { verifyPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export const remoteExitNodeGetTokenBodySchema = z.object({
|
||||
remoteExitNodeId: z.string(),
|
||||
secret: z.string(),
|
||||
token: z.string().optional()
|
||||
});
|
||||
|
||||
export type RemoteExitNodeGetTokenBody = z.infer<typeof remoteExitNodeGetTokenBodySchema>;
|
||||
|
||||
export async function getRemoteExitNodeToken(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedBody = remoteExitNodeGetTokenBodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { remoteExitNodeId, secret, token } = parsedBody.data;
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
|
||||
if (session) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`RemoteExitNode session already valid. RemoteExitNode ID: ${remoteExitNodeId}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return response<null>(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Token session already valid",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const existingRemoteExitNodeRes = await db
|
||||
.select()
|
||||
.from(remoteExitNodes)
|
||||
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
|
||||
if (!existingRemoteExitNodeRes || !existingRemoteExitNodeRes.length) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No remoteExitNode found with that remoteExitNodeId"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const existingRemoteExitNode = existingRemoteExitNodeRes[0];
|
||||
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingRemoteExitNode.secretHash
|
||||
);
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`RemoteExitNode id or secret is incorrect. RemoteExitNode: ID ${remoteExitNodeId}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
|
||||
);
|
||||
}
|
||||
|
||||
const resToken = generateSessionToken();
|
||||
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
|
||||
|
||||
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);
|
||||
|
||||
return response<{ token: string }>(res, {
|
||||
data: {
|
||||
token: resToken
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Token created successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to authenticate remoteExitNode"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { db, exitNodes, sites } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, RemoteExitNode } from "@server/db";
|
||||
import { eq, lt, isNull, and, or, inArray } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
// Track if the offline checker interval is running
|
||||
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
|
||||
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
|
||||
|
||||
/**
|
||||
* Starts the background interval that checks for clients that haven't pinged recently
|
||||
* and marks them as offline
|
||||
*/
|
||||
export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
if (offlineCheckerInterval) {
|
||||
return; // Already running
|
||||
}
|
||||
|
||||
offlineCheckerInterval = setInterval(async () => {
|
||||
try {
|
||||
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
|
||||
|
||||
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
||||
const newlyOfflineNodes = await db
|
||||
.update(exitNodes)
|
||||
.set({ online: false })
|
||||
.where(
|
||||
and(
|
||||
eq(exitNodes.online, true),
|
||||
eq(exitNodes.type, "remoteExitNode"),
|
||||
or(
|
||||
lt(exitNodes.lastPing, twoMinutesAgo),
|
||||
isNull(exitNodes.lastPing)
|
||||
)
|
||||
)
|
||||
).returning();
|
||||
|
||||
|
||||
// Update the sites to offline if they have not pinged either
|
||||
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
|
||||
|
||||
const sitesOnNode = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.online, true),
|
||||
inArray(sites.exitNodeId, exitNodeIds)
|
||||
)
|
||||
);
|
||||
|
||||
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
|
||||
for (const site of sitesOnNode) {
|
||||
if (!site.lastBandwidthUpdate) {
|
||||
continue;
|
||||
}
|
||||
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
|
||||
if (Date.now() - lastBandwidthUpdate.getTime() > 60 * 1000) {
|
||||
await db
|
||||
.update(sites)
|
||||
.set({ online: false })
|
||||
.where(eq(sites.siteId, site.siteId));
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Error in offline checker interval", { error });
|
||||
}
|
||||
}, OFFLINE_CHECK_INTERVAL);
|
||||
|
||||
logger.info("Started offline checker interval");
|
||||
};
|
||||
|
||||
/**
|
||||
* Stops the background interval that checks for offline clients
|
||||
*/
|
||||
export const stopRemoteExitNodeOfflineChecker = (): void => {
|
||||
if (offlineCheckerInterval) {
|
||||
clearInterval(offlineCheckerInterval);
|
||||
offlineCheckerInterval = null;
|
||||
logger.info("Stopped offline checker interval");
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles ping messages from clients and responds with pong
|
||||
*/
|
||||
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const remoteExitNode = c as RemoteExitNode;
|
||||
|
||||
if (!remoteExitNode) {
|
||||
logger.debug("RemoteExitNode not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!remoteExitNode.exitNodeId) {
|
||||
logger.debug("RemoteExitNode has no exit node ID!"); // this can happen if the exit node is created but not adopted yet
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Update the exit node's last ping timestamp
|
||||
await db
|
||||
.update(exitNodes)
|
||||
.set({
|
||||
lastPing: Math.floor(Date.now() / 1000),
|
||||
online: true,
|
||||
})
|
||||
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
|
||||
} catch (error) {
|
||||
logger.error("Error handling ping message", { error });
|
||||
}
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "pong",
|
||||
data: {
|
||||
timestamp: new Date().toISOString(),
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,49 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { db, RemoteExitNode, remoteExitNodes } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
const remoteExitNode = client as RemoteExitNode;
|
||||
|
||||
logger.debug("Handling register remoteExitNode message!");
|
||||
|
||||
if (!remoteExitNode) {
|
||||
logger.warn("Remote exit node not found");
|
||||
return;
|
||||
}
|
||||
|
||||
const { remoteExitNodeVersion } = message.data;
|
||||
|
||||
if (!remoteExitNodeVersion) {
|
||||
logger.warn("Remote exit node version not found");
|
||||
return;
|
||||
}
|
||||
|
||||
// update the version
|
||||
await db
|
||||
.update(remoteExitNodes)
|
||||
.set({ version: remoteExitNodeVersion })
|
||||
.where(
|
||||
eq(
|
||||
remoteExitNodes.remoteExitNodeId,
|
||||
remoteExitNode.remoteExitNodeId
|
||||
)
|
||||
);
|
||||
};
|
||||
23
server/private/routers/remoteExitNode/index.ts
Normal file
23
server/private/routers/remoteExitNode/index.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./createRemoteExitNode";
|
||||
export * from "./getRemoteExitNode";
|
||||
export * from "./listRemoteExitNodes";
|
||||
export * from "./getRemoteExitNodeToken";
|
||||
export * from "./handleRemoteExitNodeRegisterMessage";
|
||||
export * from "./handleRemoteExitNodePingMessage";
|
||||
export * from "./deleteRemoteExitNode";
|
||||
export * from "./listRemoteExitNodes";
|
||||
export * from "./pickRemoteExitNodeDefaults";
|
||||
export * from "./quickStartRemoteExitNode";
|
||||
147
server/private/routers/remoteExitNode/listRemoteExitNodes.ts
Normal file
147
server/private/routers/remoteExitNode/listRemoteExitNodes.ts
Normal file
@@ -0,0 +1,147 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, exitNodeOrgs, exitNodes } from "@server/db";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
import { eq, and, count } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const listRemoteExitNodesParamsSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const listRemoteExitNodesSchema = z.object({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().positive()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().nonnegative())
|
||||
});
|
||||
|
||||
function queryRemoteExitNodes(orgId: string) {
|
||||
return db
|
||||
.select({
|
||||
remoteExitNodeId: remoteExitNodes.remoteExitNodeId,
|
||||
dateCreated: remoteExitNodes.dateCreated,
|
||||
version: remoteExitNodes.version,
|
||||
exitNodeId: remoteExitNodes.exitNodeId,
|
||||
name: exitNodes.name,
|
||||
address: exitNodes.address,
|
||||
endpoint: exitNodes.endpoint,
|
||||
online: exitNodes.online,
|
||||
type: exitNodes.type
|
||||
})
|
||||
.from(exitNodeOrgs)
|
||||
.where(eq(exitNodeOrgs.orgId, orgId))
|
||||
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, exitNodeOrgs.exitNodeId))
|
||||
.innerJoin(
|
||||
remoteExitNodes,
|
||||
eq(remoteExitNodes.exitNodeId, exitNodeOrgs.exitNodeId)
|
||||
);
|
||||
}
|
||||
|
||||
export type ListRemoteExitNodesResponse = {
|
||||
remoteExitNodes: Awaited<ReturnType<typeof queryRemoteExitNodes>>;
|
||||
pagination: { total: number; limit: number; offset: number };
|
||||
};
|
||||
|
||||
export async function listRemoteExitNodes(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedQuery = listRemoteExitNodesSchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
const { limit, offset } = parsedQuery.data;
|
||||
|
||||
const parsedParams = listRemoteExitNodesParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (req.user && orgId && orgId !== req.userOrgId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const baseQuery = queryRemoteExitNodes(orgId);
|
||||
|
||||
const countQuery = db
|
||||
.select({ count: count() })
|
||||
.from(remoteExitNodes)
|
||||
.innerJoin(
|
||||
exitNodes,
|
||||
eq(exitNodes.exitNodeId, remoteExitNodes.exitNodeId)
|
||||
)
|
||||
.where(eq(exitNodes.type, "remoteExitNode"));
|
||||
|
||||
const remoteExitNodesList = await baseQuery.limit(limit).offset(offset);
|
||||
const totalCountResult = await countQuery;
|
||||
const totalCount = totalCountResult[0].count;
|
||||
|
||||
return response<ListRemoteExitNodesResponse>(res, {
|
||||
data: {
|
||||
remoteExitNodes: remoteExitNodesList,
|
||||
pagination: {
|
||||
total: totalCount,
|
||||
limit,
|
||||
offset
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Remote exit nodes retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { z } from "zod";
|
||||
|
||||
export type PickRemoteExitNodeDefaultsResponse = {
|
||||
remoteExitNodeId: string;
|
||||
secret: string;
|
||||
};
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function pickRemoteExitNodeDefaults(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const remoteExitNodeId = generateId(15);
|
||||
const secret = generateId(48);
|
||||
|
||||
return response<PickRemoteExitNodeDefaultsResponse>(res, {
|
||||
data: {
|
||||
remoteExitNodeId,
|
||||
secret
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Organization retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, exitNodes, exitNodeOrgs } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import moment from "moment";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import z from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
export type QuickStartRemoteExitNodeResponse = {
|
||||
remoteExitNodeId: string;
|
||||
secret: string;
|
||||
};
|
||||
|
||||
const INSTALLER_KEY = "af4e4785-7e09-11f0-b93a-74563c4e2a7e";
|
||||
|
||||
const quickStartRemoteExitNodeBodySchema = z.object({
|
||||
token: z.string()
|
||||
});
|
||||
|
||||
export async function quickStartRemoteExitNode(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { token } = parsedBody.data;
|
||||
|
||||
const tokenValidation = validateTokenOnApi(token);
|
||||
if (!tokenValidation.isValid) {
|
||||
logger.info(`Failed token validation: ${tokenValidation.message}`);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
fromError(tokenValidation.message).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const remoteExitNodeId = generateId(15);
|
||||
const secret = generateId(48);
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
await db.insert(remoteExitNodes).values({
|
||||
remoteExitNodeId,
|
||||
secretHash,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
return response<QuickStartRemoteExitNodeResponse>(res, {
|
||||
data: {
|
||||
remoteExitNodeId,
|
||||
secret
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Remote exit node created successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A remote exit node with that ID already exists"
|
||||
)
|
||||
);
|
||||
} else {
|
||||
logger.error("Failed to create remoteExitNode", e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create remoteExitNode"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a token received from the frontend.
|
||||
* @param {string} token The validation token from the request.
|
||||
* @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid.
|
||||
*/
|
||||
const validateTokenOnApi = (
|
||||
token: string
|
||||
): { isValid: boolean; message: string } => {
|
||||
if (!token) {
|
||||
return { isValid: false, message: "Error: No token provided." };
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. Decode the base64 string
|
||||
const decodedB64 = atob(token);
|
||||
|
||||
// 2. Reverse the character code manipulation
|
||||
const deobfuscated = decodedB64
|
||||
.split("")
|
||||
.map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift
|
||||
.join("");
|
||||
|
||||
// 3. Split the data to get the original secret and timestamp
|
||||
const parts = deobfuscated.split("|");
|
||||
if (parts.length !== 2) {
|
||||
throw new Error("Invalid token format.");
|
||||
}
|
||||
const receivedKey = parts[0];
|
||||
const tokenTimestamp = parseInt(parts[1], 10);
|
||||
|
||||
// 4. Check if the secret key matches
|
||||
if (receivedKey !== INSTALLER_KEY) {
|
||||
logger.info(`Token key mismatch. Received: ${receivedKey}`);
|
||||
return { isValid: false, message: "Invalid token: Key mismatch." };
|
||||
}
|
||||
|
||||
// 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks
|
||||
const now = Date.now();
|
||||
const timeDifference = now - tokenTimestamp;
|
||||
|
||||
if (timeDifference > 30000) {
|
||||
// 30 seconds
|
||||
return { isValid: false, message: "Invalid token: Expired." };
|
||||
}
|
||||
|
||||
if (timeDifference < 0) {
|
||||
// Timestamp is in the future
|
||||
return {
|
||||
isValid: false,
|
||||
message: "Invalid token: Timestamp is in the future."
|
||||
};
|
||||
}
|
||||
|
||||
// If all checks pass, the token is valid
|
||||
return { isValid: true, message: "Token is valid!" };
|
||||
} catch (error) {
|
||||
// This will catch errors from atob (if not valid base64) or other issues.
|
||||
return {
|
||||
isValid: false,
|
||||
message: `Error: ${(error as Error).message}`
|
||||
};
|
||||
}
|
||||
};
|
||||
14
server/private/routers/ws/index.ts
Normal file
14
server/private/routers/ws/index.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./ws";
|
||||
26
server/private/routers/ws/messageHandlers.ts
Normal file
26
server/private/routers/ws/messageHandlers.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import {
|
||||
handleRemoteExitNodeRegisterMessage,
|
||||
handleRemoteExitNodePingMessage,
|
||||
startRemoteExitNodeOfflineChecker
|
||||
} from "#private/routers/remoteExitNode";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
|
||||
export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"remoteExitNode/register": handleRemoteExitNodeRegisterMessage,
|
||||
"remoteExitNode/ping": handleRemoteExitNodePingMessage
|
||||
};
|
||||
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
834
server/private/routers/ws/ws.ts
Normal file
834
server/private/routers/ws/ws.ts
Normal file
@@ -0,0 +1,834 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Router, Request, Response } from "express";
|
||||
import { Server as HttpServer } from "http";
|
||||
import { WebSocket, WebSocketServer } from "ws";
|
||||
import { Socket } from "net";
|
||||
import {
|
||||
Newt,
|
||||
newts,
|
||||
NewtSession,
|
||||
olms,
|
||||
Olm,
|
||||
OlmSession,
|
||||
RemoteExitNode,
|
||||
RemoteExitNodeSession,
|
||||
remoteExitNodes
|
||||
} from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { db } from "@server/db";
|
||||
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
|
||||
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
|
||||
import logger from "@server/logger";
|
||||
import redisManager from "#private/lib/redis";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remoteExitNode";
|
||||
import { rateLimitService } from "#private/lib/rateLimit";
|
||||
import { messageHandlers } from "@server/routers/ws/messageHandlers";
|
||||
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
|
||||
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
|
||||
|
||||
// Merge public and private message handlers
|
||||
Object.assign(messageHandlers, privateMessageHandlers);
|
||||
|
||||
const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection setup
|
||||
|
||||
// Helper function to process a single message
|
||||
const processMessage = async (
|
||||
ws: AuthenticatedWebSocket,
|
||||
data: Buffer,
|
||||
clientId: string,
|
||||
clientType: ClientType
|
||||
): Promise<void> => {
|
||||
try {
|
||||
const message: WSMessage = JSON.parse(data.toString());
|
||||
|
||||
logger.debug(
|
||||
`Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||
);
|
||||
|
||||
if (!message.type || typeof message.type !== "string") {
|
||||
throw new Error("Invalid message format: missing or invalid type");
|
||||
}
|
||||
|
||||
// Check rate limiting with message type awareness
|
||||
const rateLimitResult = await rateLimitService.checkRateLimit(
|
||||
clientId,
|
||||
message.type, // Pass message type for granular limiting
|
||||
100, // max requests per window
|
||||
20, // max requests per message type per window
|
||||
60 * 1000 // window in milliseconds
|
||||
);
|
||||
if (rateLimitResult.isLimited) {
|
||||
const reason =
|
||||
rateLimitResult.reason === "global"
|
||||
? "too many messages"
|
||||
: `too many '${message.type}' messages`;
|
||||
logger.debug(
|
||||
`Rate limit exceeded for ${clientType.toUpperCase()} ID: ${clientId} - ${reason}, ignoring message`
|
||||
);
|
||||
|
||||
// Send rate limit error to client
|
||||
// ws.send(JSON.stringify({
|
||||
// type: "rate_limit_error",
|
||||
// data: {
|
||||
// message: `Rate limit exceeded: ${reason}`,
|
||||
// messageType: message.type,
|
||||
// reason: rateLimitResult.reason
|
||||
// }
|
||||
// }));
|
||||
return;
|
||||
}
|
||||
|
||||
const handler = messageHandlers[message.type];
|
||||
if (!handler) {
|
||||
throw new Error(`Unsupported message type: ${message.type}`);
|
||||
}
|
||||
|
||||
const response = await handler({
|
||||
message,
|
||||
senderWs: ws,
|
||||
client: ws.client,
|
||||
clientType: ws.clientType!,
|
||||
sendToClient,
|
||||
broadcastToAllExcept,
|
||||
connectedClients
|
||||
});
|
||||
|
||||
if (response) {
|
||||
if (response.broadcast) {
|
||||
await broadcastToAllExcept(
|
||||
response.message,
|
||||
response.excludeSender ? clientId : undefined
|
||||
);
|
||||
} else if (response.targetClientId) {
|
||||
await sendToClient(response.targetClientId, response.message);
|
||||
} else {
|
||||
ws.send(JSON.stringify(response.message));
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Message handling error:", error);
|
||||
// ws.send(JSON.stringify({
|
||||
// type: "error",
|
||||
// data: {
|
||||
// message: error instanceof Error ? error.message : "Unknown error occurred",
|
||||
// originalMessage: data.toString()
|
||||
// }
|
||||
// }));
|
||||
}
|
||||
};
|
||||
|
||||
// Helper function to process pending messages
|
||||
const processPendingMessages = async (
|
||||
ws: AuthenticatedWebSocket,
|
||||
clientId: string,
|
||||
clientType: ClientType
|
||||
): Promise<void> => {
|
||||
if (!ws.pendingMessages || ws.pendingMessages.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Processing ${ws.pendingMessages.length} pending messages for ${clientType.toUpperCase()} ID: ${clientId}`
|
||||
);
|
||||
|
||||
const jobs = [];
|
||||
for (const messageData of ws.pendingMessages) {
|
||||
jobs.push(processMessage(ws, messageData, clientId, clientType));
|
||||
}
|
||||
|
||||
await Promise.all(jobs);
|
||||
|
||||
ws.pendingMessages = []; // Clear pending messages to prevent reprocessing
|
||||
};
|
||||
|
||||
const router: Router = Router();
|
||||
const wss: WebSocketServer = new WebSocketServer({ noServer: true });
|
||||
|
||||
// Generate unique node ID for this instance
|
||||
const NODE_ID = uuidv4();
|
||||
const REDIS_CHANNEL = "websocket_messages";
|
||||
|
||||
// Client tracking map (local to this node)
|
||||
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
|
||||
|
||||
// Recovery tracking
|
||||
let isRedisRecoveryInProgress = false;
|
||||
|
||||
// Helper to get map key
|
||||
const getClientMapKey = (clientId: string) => clientId;
|
||||
|
||||
// Redis keys (generalized)
|
||||
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
|
||||
const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
|
||||
`ws:node:${nodeId}:${clientId}`;
|
||||
|
||||
// Initialize Redis subscription for cross-node messaging
|
||||
const initializeRedisSubscription = async (): Promise<void> => {
|
||||
if (!redisManager.isRedisEnabled()) return;
|
||||
|
||||
await redisManager.subscribe(
|
||||
REDIS_CHANNEL,
|
||||
async (channel: string, message: string) => {
|
||||
try {
|
||||
const redisMessage: RedisMessage = JSON.parse(message);
|
||||
|
||||
// Ignore messages from this node
|
||||
if (redisMessage.fromNodeId === NODE_ID) return;
|
||||
|
||||
if (
|
||||
redisMessage.type === "direct" &&
|
||||
redisMessage.targetClientId
|
||||
) {
|
||||
// Send to specific client on this node
|
||||
await sendToClientLocal(
|
||||
redisMessage.targetClientId,
|
||||
redisMessage.message
|
||||
);
|
||||
} else if (redisMessage.type === "broadcast") {
|
||||
// Broadcast to all clients on this node except excluded
|
||||
await broadcastToAllExceptLocal(
|
||||
redisMessage.message,
|
||||
redisMessage.excludeClientId
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error processing Redis message:", error);
|
||||
}
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
// Simple self-healing recovery function
|
||||
// Each node is responsible for restoring its own connection state to Redis
|
||||
// This approach is more efficient than cross-node coordination because:
|
||||
// 1. Each node knows its own connections (source of truth)
|
||||
// 2. No network overhead from broadcasting state between nodes
|
||||
// 3. No race conditions from simultaneous updates
|
||||
// 4. Redis becomes eventually consistent as each node restores independently
|
||||
// 5. Simpler logic with better fault tolerance
|
||||
const recoverConnectionState = async (): Promise<void> => {
|
||||
if (isRedisRecoveryInProgress) {
|
||||
logger.debug("Redis recovery already in progress, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
isRedisRecoveryInProgress = true;
|
||||
logger.info("Starting Redis connection state recovery...");
|
||||
|
||||
try {
|
||||
// Each node simply restores its own local connections to Redis
|
||||
// This is the source of truth - no need for cross-node coordination
|
||||
await restoreLocalConnectionsToRedis();
|
||||
|
||||
logger.info("Redis connection state recovery completed - restored local state");
|
||||
} catch (error) {
|
||||
logger.error("Error during Redis recovery:", error);
|
||||
} finally {
|
||||
isRedisRecoveryInProgress = false;
|
||||
}
|
||||
};
|
||||
|
||||
const restoreLocalConnectionsToRedis = async (): Promise<void> => {
|
||||
if (!redisManager.isRedisEnabled()) return;
|
||||
|
||||
logger.info("Restoring local connections to Redis...");
|
||||
let restoredCount = 0;
|
||||
|
||||
try {
|
||||
// Restore all current local connections to Redis
|
||||
for (const [clientId, clients] of connectedClients.entries()) {
|
||||
const validClients = clients.filter(client => client.readyState === WebSocket.OPEN);
|
||||
|
||||
if (validClients.length > 0) {
|
||||
// Add this node to the client's connection list
|
||||
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
|
||||
|
||||
// Store individual connection details
|
||||
for (const client of validClients) {
|
||||
if (client.connectionId) {
|
||||
await redisManager.hset(
|
||||
getNodeConnectionsKey(NODE_ID, clientId),
|
||||
client.connectionId,
|
||||
Date.now().toString()
|
||||
);
|
||||
}
|
||||
}
|
||||
restoredCount++;
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Restored ${restoredCount} client connections to Redis`);
|
||||
} catch (error) {
|
||||
logger.error("Failed to restore local connections to Redis:", error);
|
||||
}
|
||||
};
|
||||
|
||||
// Helper functions for client management
|
||||
const addClient = async (
|
||||
clientType: ClientType,
|
||||
clientId: string,
|
||||
ws: AuthenticatedWebSocket
|
||||
): Promise<void> => {
|
||||
// Generate unique connection ID
|
||||
const connectionId = uuidv4();
|
||||
ws.connectionId = connectionId;
|
||||
|
||||
// Add to local tracking
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const existingClients = connectedClients.get(mapKey) || [];
|
||||
existingClients.push(ws);
|
||||
connectedClients.set(mapKey, existingClients);
|
||||
|
||||
// Add to Redis tracking if enabled
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.hset(
|
||||
getNodeConnectionsKey(NODE_ID, clientId),
|
||||
connectionId,
|
||||
Date.now().toString()
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to add client to Redis tracking (connection still functional locally):", error);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`
|
||||
);
|
||||
};
|
||||
|
||||
const removeClient = async (
|
||||
clientType: ClientType,
|
||||
clientId: string,
|
||||
ws: AuthenticatedWebSocket
|
||||
): Promise<void> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const existingClients = connectedClients.get(mapKey) || [];
|
||||
const updatedClients = existingClients.filter((client) => client !== ws);
|
||||
if (updatedClients.length === 0) {
|
||||
connectedClients.delete(mapKey);
|
||||
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
|
||||
} catch (error) {
|
||||
logger.error("Failed to remove client from Redis tracking (cleanup will occur on recovery):", error);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`
|
||||
);
|
||||
} else {
|
||||
connectedClients.set(mapKey, updatedClients);
|
||||
|
||||
if (redisManager.isRedisEnabled() && ws.connectionId) {
|
||||
try {
|
||||
await redisManager.hdel(
|
||||
getNodeConnectionsKey(NODE_ID, clientId),
|
||||
ws.connectionId
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to remove specific connection from Redis tracking:", error);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Local message sending (within this node)
|
||||
const sendToClientLocal = async (
|
||||
clientId: string,
|
||||
message: WSMessage
|
||||
): Promise<boolean> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
if (!clients || clients.length === 0) {
|
||||
return false;
|
||||
}
|
||||
const messageString = JSON.stringify(message);
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(messageString);
|
||||
}
|
||||
});
|
||||
return true;
|
||||
};
|
||||
|
||||
const broadcastToAllExceptLocal = async (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
): Promise<void> => {
|
||||
connectedClients.forEach((clients, mapKey) => {
|
||||
const [type, id] = mapKey.split(":");
|
||||
if (!(excludeClientId && id === excludeClientId)) {
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(JSON.stringify(message));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Cross-node message sending (via Redis)
|
||||
const sendToClient = async (
|
||||
clientId: string,
|
||||
message: WSMessage
|
||||
): Promise<boolean> => {
|
||||
// Try to send locally first
|
||||
const localSent = await sendToClientLocal(clientId, message);
|
||||
|
||||
// Only send via Redis if the client is not connected locally and Redis is enabled
|
||||
if (!localSent && redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
const redisMessage: RedisMessage = {
|
||||
type: "direct",
|
||||
targetClientId: clientId,
|
||||
message,
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
|
||||
} catch (error) {
|
||||
logger.error("Failed to send message via Redis, message may be lost:", error);
|
||||
// Continue execution - local delivery already attempted
|
||||
}
|
||||
} else if (!localSent && !redisManager.isRedisEnabled()) {
|
||||
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
|
||||
logger.debug(`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`);
|
||||
}
|
||||
|
||||
return localSent;
|
||||
};
|
||||
|
||||
const broadcastToAllExcept = async (
|
||||
message: WSMessage,
|
||||
excludeClientId?: string
|
||||
): Promise<void> => {
|
||||
// Broadcast locally
|
||||
await broadcastToAllExceptLocal(message, excludeClientId);
|
||||
|
||||
// If Redis is enabled, also broadcast via Redis pub/sub to other nodes
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
const redisMessage: RedisMessage = {
|
||||
type: "broadcast",
|
||||
excludeClientId,
|
||||
message,
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
|
||||
} catch (error) {
|
||||
logger.error("Failed to broadcast message via Redis, remote nodes may not receive it:", error);
|
||||
// Continue execution - local broadcast already completed
|
||||
}
|
||||
} else {
|
||||
logger.debug("Redis unavailable - broadcast limited to local node only");
|
||||
}
|
||||
};
|
||||
|
||||
// Check if a client has active connections across all nodes
|
||||
const hasActiveConnections = async (clientId: string): Promise<boolean> => {
|
||||
if (!redisManager.isRedisEnabled()) {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
return !!(clients && clients.length > 0);
|
||||
}
|
||||
|
||||
const activeNodes = await redisManager.smembers(
|
||||
getConnectionsKey(clientId)
|
||||
);
|
||||
return activeNodes.length > 0;
|
||||
};
|
||||
|
||||
// Get all active nodes for a client
|
||||
const getActiveNodes = async (
|
||||
clientType: ClientType,
|
||||
clientId: string
|
||||
): Promise<string[]> => {
|
||||
if (!redisManager.isRedisEnabled()) {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
return clients && clients.length > 0 ? [NODE_ID] : [];
|
||||
}
|
||||
|
||||
return await redisManager.smembers(getConnectionsKey(clientId));
|
||||
};
|
||||
|
||||
// Token verification middleware
|
||||
const verifyToken = async (
|
||||
token: string,
|
||||
clientType: ClientType
|
||||
): Promise<TokenPayload | null> => {
|
||||
try {
|
||||
if (clientType === "newt") {
|
||||
const { session, newt } = await validateNewtSessionToken(token);
|
||||
if (!session || !newt) {
|
||||
return null;
|
||||
}
|
||||
const existingNewt = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.newtId, newt.newtId));
|
||||
if (!existingNewt || !existingNewt[0]) {
|
||||
return null;
|
||||
}
|
||||
return { client: existingNewt[0], session, clientType };
|
||||
} else if (clientType === "olm") {
|
||||
const { session, olm } = await validateOlmSessionToken(token);
|
||||
if (!session || !olm) {
|
||||
return null;
|
||||
}
|
||||
const existingOlm = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.olmId, olm.olmId));
|
||||
if (!existingOlm || !existingOlm[0]) {
|
||||
return null;
|
||||
}
|
||||
return { client: existingOlm[0], session, clientType };
|
||||
} else if (clientType === "remoteExitNode") {
|
||||
const { session, remoteExitNode } =
|
||||
await validateRemoteExitNodeSessionToken(token);
|
||||
if (!session || !remoteExitNode) {
|
||||
return null;
|
||||
}
|
||||
const existingRemoteExitNode = await db
|
||||
.select()
|
||||
.from(remoteExitNodes)
|
||||
.where(
|
||||
eq(
|
||||
remoteExitNodes.remoteExitNodeId,
|
||||
remoteExitNode.remoteExitNodeId
|
||||
)
|
||||
);
|
||||
if (!existingRemoteExitNode || !existingRemoteExitNode[0]) {
|
||||
return null;
|
||||
}
|
||||
return { client: existingRemoteExitNode[0], session, clientType };
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
logger.error("Token verification failed:", error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
const setupConnection = async (
|
||||
ws: AuthenticatedWebSocket,
|
||||
client: Newt | Olm | RemoteExitNode,
|
||||
clientType: ClientType
|
||||
): Promise<void> => {
|
||||
logger.info("Establishing websocket connection");
|
||||
if (!client) {
|
||||
logger.error("Connection attempt without client");
|
||||
return ws.terminate();
|
||||
}
|
||||
|
||||
ws.client = client;
|
||||
ws.clientType = clientType;
|
||||
ws.isFullyConnected = false;
|
||||
ws.pendingMessages = [];
|
||||
|
||||
// Get client ID first
|
||||
let clientId: string;
|
||||
if (clientType === "newt") {
|
||||
clientId = (client as Newt).newtId;
|
||||
} else if (clientType === "olm") {
|
||||
clientId = (client as Olm).olmId;
|
||||
} else if (clientType === "remoteExitNode") {
|
||||
clientId = (client as RemoteExitNode).remoteExitNodeId;
|
||||
} else {
|
||||
throw new Error(`Unknown client type: ${clientType}`);
|
||||
}
|
||||
|
||||
// Set up message handler FIRST to prevent race condition
|
||||
ws.on("message", async (data) => {
|
||||
if (!ws.isFullyConnected) {
|
||||
// Queue message for later processing with limits
|
||||
ws.pendingMessages = ws.pendingMessages || [];
|
||||
|
||||
if (ws.pendingMessages.length >= MAX_PENDING_MESSAGES) {
|
||||
logger.warn(
|
||||
`Too many pending messages for ${clientType.toUpperCase()} ID: ${clientId}, dropping oldest message`
|
||||
);
|
||||
ws.pendingMessages.shift(); // Remove oldest message
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
|
||||
);
|
||||
ws.pendingMessages.push(data as Buffer);
|
||||
return;
|
||||
}
|
||||
|
||||
await processMessage(ws, data as Buffer, clientId, clientType);
|
||||
});
|
||||
|
||||
// Set up other event handlers before async operations
|
||||
ws.on("close", async () => {
|
||||
// Clear any pending messages to prevent memory leaks
|
||||
if (ws.pendingMessages) {
|
||||
ws.pendingMessages = [];
|
||||
}
|
||||
await removeClient(clientType, clientId, ws);
|
||||
logger.info(
|
||||
`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`
|
||||
);
|
||||
});
|
||||
|
||||
ws.on("error", (error: Error) => {
|
||||
logger.error(
|
||||
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
|
||||
error
|
||||
);
|
||||
});
|
||||
|
||||
try {
|
||||
await addClient(clientType, clientId, ws);
|
||||
|
||||
// Mark connection as fully established
|
||||
ws.isFullyConnected = true;
|
||||
|
||||
logger.info(
|
||||
`WebSocket connection fully established and ready - ${clientType.toUpperCase()} ID: ${clientId}`
|
||||
);
|
||||
|
||||
// Process any messages that were queued while connection was being established
|
||||
await processPendingMessages(ws, clientId, clientType);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to fully establish connection for ${clientType.toUpperCase()} ID: ${clientId}:`,
|
||||
error
|
||||
);
|
||||
// ws.send(JSON.stringify({
|
||||
// type: "connection_error",
|
||||
// data: {
|
||||
// message: "Failed to establish connection"
|
||||
// }
|
||||
// }));
|
||||
ws.terminate();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Router endpoint
|
||||
router.get("/ws", (req: Request, res: Response) => {
|
||||
res.status(200).send("WebSocket endpoint");
|
||||
});
|
||||
|
||||
// WebSocket upgrade handler
|
||||
const handleWSUpgrade = (server: HttpServer): void => {
|
||||
server.on(
|
||||
"upgrade",
|
||||
async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
|
||||
try {
|
||||
const url = new URL(
|
||||
request.url || "",
|
||||
`http://${request.headers.host}`
|
||||
);
|
||||
const token =
|
||||
url.searchParams.get("token") ||
|
||||
request.headers["sec-websocket-protocol"] ||
|
||||
"";
|
||||
let clientType = url.searchParams.get(
|
||||
"clientType"
|
||||
) as ClientType;
|
||||
|
||||
if (!clientType) {
|
||||
clientType = "newt";
|
||||
}
|
||||
|
||||
if (
|
||||
!token ||
|
||||
!clientType ||
|
||||
!["newt", "olm", "remoteExitNode"].includes(clientType)
|
||||
) {
|
||||
logger.warn(
|
||||
"Unauthorized connection attempt: invalid token or client type..."
|
||||
);
|
||||
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
|
||||
socket.destroy();
|
||||
return;
|
||||
}
|
||||
|
||||
const tokenPayload = await verifyToken(token, clientType);
|
||||
if (!tokenPayload) {
|
||||
logger.debug(
|
||||
"Unauthorized connection attempt: invalid token..."
|
||||
);
|
||||
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
|
||||
socket.destroy();
|
||||
return;
|
||||
}
|
||||
|
||||
wss.handleUpgrade(
|
||||
request,
|
||||
socket,
|
||||
head,
|
||||
(ws: AuthenticatedWebSocket) => {
|
||||
setupConnection(
|
||||
ws,
|
||||
tokenPayload.client,
|
||||
tokenPayload.clientType
|
||||
);
|
||||
}
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("WebSocket upgrade error:", error);
|
||||
socket.write("HTTP/1.1 500 Internal Server Error\r\n\r\n");
|
||||
socket.destroy();
|
||||
}
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
// Add periodic connection state sync to handle Redis disconnections/reconnections
|
||||
const startPeriodicStateSync = (): void => {
|
||||
// Lightweight sync every 5 minutes - just restore our own state
|
||||
setInterval(async () => {
|
||||
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
|
||||
try {
|
||||
await restoreLocalConnectionsToRedis();
|
||||
logger.debug("Periodic connection state sync completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during periodic connection state sync:", error);
|
||||
}
|
||||
}
|
||||
}, 5 * 60 * 1000); // 5 minutes
|
||||
|
||||
// Cleanup stale connections every 15 minutes
|
||||
setInterval(async () => {
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await cleanupStaleConnections();
|
||||
logger.debug("Periodic connection cleanup completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during periodic connection cleanup:", error);
|
||||
}
|
||||
}
|
||||
}, 15 * 60 * 1000); // 15 minutes
|
||||
};
|
||||
|
||||
const cleanupStaleConnections = async (): Promise<void> => {
|
||||
if (!redisManager.isRedisEnabled()) return;
|
||||
|
||||
try {
|
||||
const nodeKeys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
|
||||
|
||||
for (const nodeKey of nodeKeys) {
|
||||
const connections = await redisManager.hgetall(nodeKey);
|
||||
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, '');
|
||||
const localClients = connectedClients.get(clientId) || [];
|
||||
const localConnectionIds = localClients
|
||||
.filter(client => client.readyState === WebSocket.OPEN)
|
||||
.map(client => client.connectionId)
|
||||
.filter(Boolean);
|
||||
|
||||
// Remove Redis entries for connections that no longer exist locally
|
||||
for (const [connectionId, timestamp] of Object.entries(connections)) {
|
||||
if (!localConnectionIds.includes(connectionId)) {
|
||||
await redisManager.hdel(nodeKey, connectionId);
|
||||
logger.debug(`Cleaned up stale connection: ${connectionId} for client: ${clientId}`);
|
||||
}
|
||||
}
|
||||
|
||||
// If no connections remain for this client, remove from Redis entirely
|
||||
const remainingConnections = await redisManager.hgetall(nodeKey);
|
||||
if (Object.keys(remainingConnections).length === 0) {
|
||||
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.del(nodeKey);
|
||||
logger.debug(`Cleaned up empty connection tracking for client: ${clientId}`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error cleaning up stale connections:", error);
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize Redis subscription when the module is loaded
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
initializeRedisSubscription().catch((error) => {
|
||||
logger.error("Failed to initialize Redis subscription:", error);
|
||||
});
|
||||
|
||||
// Register recovery callback with Redis manager
|
||||
// When Redis reconnects, each node simply restores its own local state
|
||||
redisManager.onReconnection(async () => {
|
||||
logger.info("Redis reconnected, starting WebSocket state recovery...");
|
||||
await recoverConnectionState();
|
||||
});
|
||||
|
||||
// Start periodic state synchronization
|
||||
startPeriodicStateSync();
|
||||
|
||||
logger.info(
|
||||
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
|
||||
);
|
||||
} else {
|
||||
logger.debug(
|
||||
"WebSocket handler initialized in local mode"
|
||||
);
|
||||
}
|
||||
|
||||
// Cleanup function for graceful shutdown
|
||||
const cleanup = async (): Promise<void> => {
|
||||
try {
|
||||
// Close all WebSocket connections
|
||||
connectedClients.forEach((clients) => {
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.terminate();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Clean up Redis tracking for this node
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
const keys =
|
||||
(await redisManager
|
||||
.getClient()
|
||||
?.keys(`ws:node:${NODE_ID}:*`)) || [];
|
||||
if (keys.length > 0) {
|
||||
await Promise.all(keys.map((key) => redisManager.del(key)));
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("WebSocket cleanup completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during WebSocket cleanup:", error);
|
||||
}
|
||||
};
|
||||
|
||||
export {
|
||||
router,
|
||||
handleWSUpgrade,
|
||||
sendToClient,
|
||||
broadcastToAllExcept,
|
||||
connectedClients,
|
||||
hasActiveConnections,
|
||||
getActiveNodes,
|
||||
NODE_ID,
|
||||
cleanup
|
||||
};
|
||||
Reference in New Issue
Block a user