Mirror of https://github.com/fosrl/pangolin.git
Commit: Chungus
@@ -7,16 +7,20 @@ import {
    errorHandlerMiddleware,
    notFoundMiddleware
} from "@server/middlewares";
import { corsWithLoginPageSupport } from "@server/middlewares/private/corsWithLoginPage";
import { authenticated, unauthenticated } from "@server/routers/external";
import { router as wsRouter, handleWSUpgrade } from "@server/routers/ws";
import { logIncomingMiddleware } from "./middlewares/logIncoming";
import { csrfProtectionMiddleware } from "./middlewares/csrfProtection";
import helmet from "helmet";
import { stripeWebhookHandler } from "@server/routers/private/billing/webhooks";
import { build } from "./build";
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
import createHttpError from "http-errors";
import HttpCode from "./types/HttpCode";
import requestTimeoutMiddleware from "./middlewares/requestTimeout";
-import { createStore } from "./lib/rateLimitStore";
+import { createStore } from "@server/lib/private/rateLimitStore";
import hybridRouter from "@server/routers/private/hybrid";

const dev = config.isDev;
const externalPort = config.getRawConfig().server.external_port;
@@ -30,26 +34,39 @@ export function createApiServer() {
        apiServer.set("trust proxy", trustProxy);
    }

    if (build == "saas") {
        apiServer.post(
            `${prefix}/billing/webhooks`,
            express.raw({ type: "application/json" }),
            stripeWebhookHandler
        );
    }

    const corsConfig = config.getRawConfig().server.cors;

    if (build == "oss") {
        const options = {
            ...(corsConfig?.origins
                ? { origin: corsConfig.origins }
                : {
                      origin: (origin: any, callback: any) => {
                          callback(null, true);
                      }
                  }),
            ...(corsConfig?.methods && { methods: corsConfig.methods }),
            ...(corsConfig?.allowed_headers && {
                allowedHeaders: corsConfig.allowed_headers
            }),
            credentials: !(corsConfig?.credentials === false)
        };

        logger.debug("Using CORS options", options);

        apiServer.use(cors(options));
    } else {
        // Use the custom CORS middleware with loginPage support
        apiServer.use(corsWithLoginPageSupport(corsConfig));
    }
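
    // For illustration (hypothetical config): with corsConfig = { origins:
    // ["https://app.example.com"] } the spread-conditionals above resolve to
    // { origin: ["https://app.example.com"], credentials: true }. Unset fields
    // are omitted entirely, and origin falls back to a reflect-all callback.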

    if (!dev) {
        apiServer.use(helmet());

@@ -70,7 +87,8 @@ export function createApiServer() {
            60 *
            1000,
        max: config.getRawConfig().rate_limits.global.max_requests,
-        keyGenerator: (req) => `apiServerGlobal:${ipKeyGenerator(req.ip || "")}:${req.path}`,
+        keyGenerator: (req) =>
+            `apiServerGlobal:${ipKeyGenerator(req.ip || "")}:${req.path}`,
        handler: (req, res, next) => {
            const message = `Rate limit exceeded. You can make ${config.getRawConfig().rate_limits.global.max_requests} requests every ${config.getRawConfig().rate_limits.global.window_minutes} minute(s).`;
            return next(

@@ -85,6 +103,9 @@ export function createApiServer() {
    // API routes
    apiServer.use(logIncomingMiddleware);
    apiServer.use(prefix, unauthenticated);
    if (build !== "oss") {
        apiServer.use(`${prefix}/hybrid`, hybridRouter);
    }
    apiServer.use(prefix, authenticated);

    // WebSocket routes

@@ -4,6 +4,7 @@ import { userActions, roleActions, userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { sendUsageNotification } from "@server/routers/org";

export enum ActionsEnum {
    createOrgUser = "createOrgUser",
@@ -98,10 +99,23 @@ export enum ActionsEnum {
    listApiKeyActions = "listApiKeyActions",
    listApiKeys = "listApiKeys",
    getApiKey = "getApiKey",
    getCertificate = "getCertificate",
    restartCertificate = "restartCertificate",
    billing = "billing",
    createOrgDomain = "createOrgDomain",
    deleteOrgDomain = "deleteOrgDomain",
    restartOrgDomain = "restartOrgDomain",
    sendUsageNotification = "sendUsageNotification",
    createRemoteExitNode = "createRemoteExitNode",
    updateRemoteExitNode = "updateRemoteExitNode",
    getRemoteExitNode = "getRemoteExitNode",
    listRemoteExitNode = "listRemoteExitNode",
    deleteRemoteExitNode = "deleteRemoteExitNode",
    updateOrgUser = "updateOrgUser",
    createLoginPage = "createLoginPage",
    updateLoginPage = "updateLoginPage",
    getLoginPage = "getLoginPage",
    deleteLoginPage = "deleteLoginPage",
    applyBlueprint = "applyBlueprint"
}

@@ -3,13 +3,7 @@ import {
    encodeHexLowerCase
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
-import {
-    resourceSessions,
-    Session,
-    sessions,
-    User,
-    users
-} from "@server/db";
+import { resourceSessions, Session, sessions, User, users } from "@server/db";
import { db } from "@server/db";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
@@ -24,8 +18,9 @@ export const SESSION_COOKIE_EXPIRES =
    60 *
    60 *
    config.getRawConfig().server.dashboard_session_length_hours;
-export const COOKIE_DOMAIN = config.getRawConfig().app.dashboard_url ?
-    "." + new URL(config.getRawConfig().app.dashboard_url!).hostname : undefined;
+export const COOKIE_DOMAIN = config.getRawConfig().app.dashboard_url
+    ? new URL(config.getRawConfig().app.dashboard_url!).hostname
+    : undefined;
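
// Sketch (hypothetical URL): for dashboard_url "https://pangolin.example.com"
// the old form yielded ".pangolin.example.com" and the new form
// "pangolin.example.com". RFC 6265 clients ignore a leading dot in the cookie
// Domain attribute, so the effective cookie scope is unchanged.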

export function generateSessionToken(): string {
    const bytes = new Uint8Array(20);
@@ -98,8 +93,8 @@ export async function invalidateSession(sessionId: string): Promise<void> {
    try {
        await db.transaction(async (trx) => {
            await trx
                .delete(resourceSessions)
                .where(eq(resourceSessions.userSessionId, sessionId));
            await trx.delete(sessions).where(eq(sessions.sessionId, sessionId));
        });
    } catch (e) {
@@ -111,9 +106,9 @@ export async function invalidateAllSessions(userId: string): Promise<void> {
    try {
        await db.transaction(async (trx) => {
            const userSessions = await trx
                .select()
                .from(sessions)
                .where(eq(sessions.userId, userId));
            await trx.delete(resourceSessions).where(
                inArray(
                    resourceSessions.userSessionId,

server/auth/sessions/privateRemoteExitNode.ts (new file, 85 lines)
@@ -0,0 +1,85 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import {
    encodeHexLowerCase,
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { RemoteExitNode, remoteExitNodes, remoteExitNodeSessions, RemoteExitNodeSession } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";

export const EXPIRES = 1000 * 60 * 60 * 24 * 30;

export async function createRemoteExitNodeSession(
    token: string,
    remoteExitNodeId: string,
): Promise<RemoteExitNodeSession> {
    const sessionId = encodeHexLowerCase(
        sha256(new TextEncoder().encode(token)),
    );
    const session: RemoteExitNodeSession = {
        sessionId: sessionId,
        remoteExitNodeId,
        expiresAt: new Date(Date.now() + EXPIRES).getTime(),
    };
    await db.insert(remoteExitNodeSessions).values(session);
    return session;
}

export async function validateRemoteExitNodeSessionToken(
    token: string,
): Promise<SessionValidationResult> {
    const sessionId = encodeHexLowerCase(
        sha256(new TextEncoder().encode(token)),
    );
    const result = await db
        .select({ remoteExitNode: remoteExitNodes, session: remoteExitNodeSessions })
        .from(remoteExitNodeSessions)
        .innerJoin(remoteExitNodes, eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodes.remoteExitNodeId))
        .where(eq(remoteExitNodeSessions.sessionId, sessionId));
    if (result.length < 1) {
        return { session: null, remoteExitNode: null };
    }
    const { remoteExitNode, session } = result[0];
    if (Date.now() >= session.expiresAt) {
        await db
            .delete(remoteExitNodeSessions)
            .where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
        return { session: null, remoteExitNode: null };
    }
    if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
        session.expiresAt = new Date(
            Date.now() + EXPIRES,
        ).getTime();
        await db
            .update(remoteExitNodeSessions)
            .set({
                expiresAt: session.expiresAt,
            })
            .where(eq(remoteExitNodeSessions.sessionId, session.sessionId));
    }
    return { session, remoteExitNode };
}

export async function invalidateRemoteExitNodeSession(sessionId: string): Promise<void> {
    await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.sessionId, sessionId));
}

export async function invalidateAllRemoteExitNodeSessions(remoteExitNodeId: string): Promise<void> {
    await db.delete(remoteExitNodeSessions).where(eq(remoteExitNodeSessions.remoteExitNodeId, remoteExitNodeId));
}

export type SessionValidationResult =
    | { session: RemoteExitNodeSession; remoteExitNode: RemoteExitNode }
    | { session: null; remoteExitNode: null };
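
A minimal usage sketch of the session lifecycle above. The token generator and the node id are stand-ins (only sha256(token) is ever persisted); sessions past the halfway point of EXPIRES are silently renewed by the validator:

import { randomBytes } from "crypto";
import {
    createRemoteExitNodeSession,
    validateRemoteExitNodeSessionToken
} from "@server/auth/sessions/privateRemoteExitNode";

async function demo() {
    // Hypothetical token source for illustration.
    const token = randomBytes(20).toString("hex");
    await createRemoteExitNodeSession(token, "exit-node-123");

    // Validation re-derives the session id from the raw token.
    const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
    if (session === null) {
        throw new Error("invalid or expired remote exit node session");
    }
    console.log(`validated session for ${remoteExitNode.remoteExitNodeId}`);
}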

@@ -199,14 +199,14 @@ export function serializeResourceSessionCookie(
    const now = new Date().getTime();
    if (!isHttp) {
        if (expiresAt === undefined) {
-            return `${cookieName}_s.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Secure; Domain=${"." + domain}`;
+            return `${cookieName}_s.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Secure; Domain=${domain}`;
        }
-        return `${cookieName}_s.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Secure; Domain=${"." + domain}`;
+        return `${cookieName}_s.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Secure; Domain=${domain}`;
    } else {
        if (expiresAt === undefined) {
-            return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Domain=${"." + domain}`;
+            return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Domain=${domain}`;
        }
-        return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Domain=${"." + domain}`;
+        return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Domain=${domain}`;
    }
}

@@ -216,9 +216,9 @@ export function createBlankResourceSessionTokenCookie(
    isHttp: boolean = false
): string {
    if (!isHttp) {
-        return `${cookieName}_s=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Secure; Domain=${"." + domain}`;
+        return `${cookieName}_s=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Secure; Domain=${domain}`;
    } else {
-        return `${cookieName}=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Domain=${"." + domain}`;
+        return `${cookieName}=; HttpOnly; SameSite=Lax; Max-Age=0; Path=/; Domain=${domain}`;
    }
}
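
For illustration, with hypothetical inputs (cookieName "p_session", token "abc123", domain "example.com", isHttp false, no expiry) the serialized value changes only in the Domain attribute:

// Before: p_session_s.<now>=abc123; HttpOnly; SameSite=Lax; Path=/; Secure; Domain=.example.com
// After:  p_session_s.<now>=abc123; HttpOnly; SameSite=Lax; Path=/; Secure; Domain=example.com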

@@ -1 +0,0 @@
-export const build = "oss" as any;
server/db/countries.ts (new file, 1014 lines)
File diff suppressed because it is too large

server/db/maxmind.ts (new file, 13 lines)
@@ -0,0 +1,13 @@
import maxmind, { CountryResponse, Reader } from "maxmind";
import config from "@server/lib/config";

let maxmindLookup: Reader<CountryResponse> | null;
if (config.getRawConfig().server.maxmind_db_path) {
    maxmindLookup = await maxmind.open<CountryResponse>(
        config.getRawConfig().server.maxmind_db_path!
    );
} else {
    maxmindLookup = null;
}

export { maxmindLookup };
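
A minimal sketch of consuming the exported reader; Reader.get() is the maxmind package's lookup call and returns null for unknown addresses (the IP here is illustrative):

import { maxmindLookup } from "@server/db/maxmind";

// Hypothetical request IP; in practice this would come from req.ip.
const ip = "8.8.8.8";

if (maxmindLookup) {
    const response = maxmindLookup.get(ip);
    // CountryResponse exposes the ISO code when the DB has a match.
    console.log(response?.country?.iso_code ?? "unknown");
}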

@@ -39,7 +39,7 @@ function createDb() {
        connectionString,
        max: 20,
        idleTimeoutMillis: 30000,
-       connectionTimeoutMillis: 2000,
+       connectionTimeoutMillis: 5000,
    });

    const replicas = [];
@@ -52,7 +52,7 @@ function createDb() {
            connectionString: conn.connection_string,
            max: 10,
            idleTimeoutMillis: 30000,
-           connectionTimeoutMillis: 2000,
+           connectionTimeoutMillis: 5000,
        });
        replicas.push(DrizzlePostgres(replicaPool));
    }

@@ -1,2 +1,3 @@
export * from "./driver";
export * from "./schema";
export * from "./privateSchema";
server/db/pg/privateSchema.ts (new file, 245 lines)
@@ -0,0 +1,245 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import {
    pgTable,
    serial,
    varchar,
    boolean,
    integer,
    bigint,
    real,
    text
} from "drizzle-orm/pg-core";
import { InferSelectModel } from "drizzle-orm";
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";

export const certificates = pgTable("certificates", {
    certId: serial("certId").primaryKey(),
    domain: varchar("domain", { length: 255 }).notNull().unique(),
    domainId: varchar("domainId").references(() => domains.domainId, {
        onDelete: "cascade"
    }),
    wildcard: boolean("wildcard").default(false),
    status: varchar("status", { length: 50 }).notNull().default("pending"), // pending, requested, valid, expired, failed
    expiresAt: bigint("expiresAt", { mode: "number" }),
    lastRenewalAttempt: bigint("lastRenewalAttempt", { mode: "number" }),
    createdAt: bigint("createdAt", { mode: "number" }).notNull(),
    updatedAt: bigint("updatedAt", { mode: "number" }).notNull(),
    orderId: varchar("orderId", { length: 500 }),
    errorMessage: text("errorMessage"),
    renewalCount: integer("renewalCount").default(0),
    certFile: text("certFile"),
    keyFile: text("keyFile")
});

export const dnsChallenge = pgTable("dnsChallenges", {
    dnsChallengeId: serial("dnsChallengeId").primaryKey(),
    domain: varchar("domain", { length: 255 }).notNull(),
    token: varchar("token", { length: 255 }).notNull(),
    keyAuthorization: varchar("keyAuthorization", { length: 1000 }).notNull(),
    createdAt: bigint("createdAt", { mode: "number" }).notNull(),
    expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
    completed: boolean("completed").default(false)
});

export const account = pgTable("account", {
    accountId: serial("accountId").primaryKey(),
    userId: varchar("userId")
        .notNull()
        .references(() => users.userId, { onDelete: "cascade" })
});

export const customers = pgTable("customers", {
    customerId: varchar("customerId", { length: 255 }).primaryKey().notNull(),
    orgId: varchar("orgId", { length: 255 })
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" }),
    // accountId: integer("accountId")
    //     .references(() => account.accountId, { onDelete: "cascade" }), // Optional, if using accounts
    email: varchar("email", { length: 255 }),
    name: varchar("name", { length: 255 }),
    phone: varchar("phone", { length: 50 }),
    address: text("address"),
    createdAt: bigint("createdAt", { mode: "number" }).notNull(),
    updatedAt: bigint("updatedAt", { mode: "number" }).notNull()
});

export const subscriptions = pgTable("subscriptions", {
    subscriptionId: varchar("subscriptionId", { length: 255 })
        .primaryKey()
        .notNull(),
    customerId: varchar("customerId", { length: 255 })
        .notNull()
        .references(() => customers.customerId, { onDelete: "cascade" }),
    status: varchar("status", { length: 50 }).notNull().default("active"), // active, past_due, canceled, unpaid
    canceledAt: bigint("canceledAt", { mode: "number" }),
    createdAt: bigint("createdAt", { mode: "number" }).notNull(),
    updatedAt: bigint("updatedAt", { mode: "number" }),
    billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" })
});

export const subscriptionItems = pgTable("subscriptionItems", {
    subscriptionItemId: serial("subscriptionItemId").primaryKey(),
    subscriptionId: varchar("subscriptionId", { length: 255 })
        .notNull()
        .references(() => subscriptions.subscriptionId, {
            onDelete: "cascade"
        }),
    planId: varchar("planId", { length: 255 }).notNull(),
    priceId: varchar("priceId", { length: 255 }),
    meterId: varchar("meterId", { length: 255 }),
    unitAmount: real("unitAmount"),
    tiers: text("tiers"),
    interval: varchar("interval", { length: 50 }),
    currentPeriodStart: bigint("currentPeriodStart", { mode: "number" }),
    currentPeriodEnd: bigint("currentPeriodEnd", { mode: "number" }),
    name: varchar("name", { length: 255 })
});

export const accountDomains = pgTable("accountDomains", {
    accountId: integer("accountId")
        .notNull()
        .references(() => account.accountId, { onDelete: "cascade" }),
    domainId: varchar("domainId")
        .notNull()
        .references(() => domains.domainId, { onDelete: "cascade" })
});

export const usage = pgTable("usage", {
    usageId: varchar("usageId", { length: 255 }).primaryKey(),
    featureId: varchar("featureId", { length: 255 }).notNull(),
    orgId: varchar("orgId")
        .references(() => orgs.orgId, { onDelete: "cascade" })
        .notNull(),
    meterId: varchar("meterId", { length: 255 }),
    instantaneousValue: real("instantaneousValue"),
    latestValue: real("latestValue").notNull(),
    previousValue: real("previousValue"),
    updatedAt: bigint("updatedAt", { mode: "number" }).notNull(),
    rolledOverAt: bigint("rolledOverAt", { mode: "number" }),
    nextRolloverAt: bigint("nextRolloverAt", { mode: "number" })
});

export const limits = pgTable("limits", {
    limitId: varchar("limitId", { length: 255 }).primaryKey(),
    featureId: varchar("featureId", { length: 255 }).notNull(),
    orgId: varchar("orgId")
        .references(() => orgs.orgId, {
            onDelete: "cascade"
        })
        .notNull(),
    value: real("value"),
    description: text("description")
});

export const usageNotifications = pgTable("usageNotifications", {
    notificationId: serial("notificationId").primaryKey(),
    orgId: varchar("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" }),
    featureId: varchar("featureId", { length: 255 }).notNull(),
    limitId: varchar("limitId", { length: 255 }).notNull(),
    notificationType: varchar("notificationType", { length: 50 }).notNull(),
    sentAt: bigint("sentAt", { mode: "number" }).notNull()
});

export const domainNamespaces = pgTable("domainNamespaces", {
    domainNamespaceId: varchar("domainNamespaceId", {
        length: 255
    }).primaryKey(),
    domainId: varchar("domainId")
        .references(() => domains.domainId, {
            onDelete: "set null"
        })
        .notNull()
});

export const exitNodeOrgs = pgTable("exitNodeOrgs", {
    exitNodeId: integer("exitNodeId")
        .notNull()
        .references(() => exitNodes.exitNodeId, { onDelete: "cascade" }),
    orgId: text("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" })
});

export const remoteExitNodes = pgTable("remoteExitNode", {
    remoteExitNodeId: varchar("id").primaryKey(),
    secretHash: varchar("secretHash").notNull(),
    dateCreated: varchar("dateCreated").notNull(),
    version: varchar("version"),
    exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, {
        onDelete: "cascade"
    })
});

export const remoteExitNodeSessions = pgTable("remoteExitNodeSession", {
    sessionId: varchar("id").primaryKey(),
    remoteExitNodeId: varchar("remoteExitNodeId")
        .notNull()
        .references(() => remoteExitNodes.remoteExitNodeId, {
            onDelete: "cascade"
        }),
    expiresAt: bigint("expiresAt", { mode: "number" }).notNull()
});

export const loginPage = pgTable("loginPage", {
    loginPageId: serial("loginPageId").primaryKey(),
    subdomain: varchar("subdomain"),
    fullDomain: varchar("fullDomain"),
    exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, {
        onDelete: "set null"
    }),
    domainId: varchar("domainId").references(() => domains.domainId, {
        onDelete: "set null"
    })
});

export const loginPageOrg = pgTable("loginPageOrg", {
    loginPageId: integer("loginPageId")
        .notNull()
        .references(() => loginPage.loginPageId, { onDelete: "cascade" }),
    orgId: varchar("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" })
});

export const sessionTransferToken = pgTable("sessionTransferToken", {
    token: varchar("token").primaryKey(),
    sessionId: varchar("sessionId")
        .notNull()
        .references(() => sessions.sessionId, {
            onDelete: "cascade"
        }),
    encryptedSession: text("encryptedSession").notNull(),
    expiresAt: bigint("expiresAt", { mode: "number" }).notNull()
});

export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
export type Certificate = InferSelectModel<typeof certificates>;
export type DnsChallenge = InferSelectModel<typeof dnsChallenge>;
export type Customer = InferSelectModel<typeof customers>;
export type Subscription = InferSelectModel<typeof subscriptions>;
export type SubscriptionItem = InferSelectModel<typeof subscriptionItems>;
export type Usage = InferSelectModel<typeof usage>;
export type UsageLimit = InferSelectModel<typeof limits>;
export type AccountDomain = InferSelectModel<typeof accountDomains>;
export type UsageNotification = InferSelectModel<typeof usageNotifications>;
export type RemoteExitNode = InferSelectModel<typeof remoteExitNodes>;
export type RemoteExitNodeSession = InferSelectModel<
    typeof remoteExitNodeSessions
>;
export type ExitNodeOrg = InferSelectModel<typeof exitNodeOrgs>;
export type LoginPage = InferSelectModel<typeof loginPage>;
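
A minimal sketch of querying two of these tables with drizzle-orm (the node id is a made-up value):

import { eq } from "drizzle-orm";
import { db, remoteExitNodes, remoteExitNodeSessions } from "@server/db";

async function lookupNode() {
    // Fetch a remote exit node row by primary key.
    const nodes = await db
        .select()
        .from(remoteExitNodes)
        .where(eq(remoteExitNodes.remoteExitNodeId, "node-abc"));

    // Fetch all sessions belonging to that node.
    const sessions = await db
        .select()
        .from(remoteExitNodeSessions)
        .where(eq(remoteExitNodeSessions.remoteExitNodeId, "node-abc"));

    return { nodes, sessions };
}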

@@ -128,6 +128,27 @@ export const targets = pgTable("targets", {
    rewritePathType: text("rewritePathType") // exact, prefix, regex, stripPrefix
});

export const targetHealthCheck = pgTable("targetHealthCheck", {
    targetHealthCheckId: serial("targetHealthCheckId").primaryKey(),
    targetId: integer("targetId")
        .notNull()
        .references(() => targets.targetId, { onDelete: "cascade" }),
    hcEnabled: boolean("hcEnabled").notNull().default(false),
    hcPath: varchar("hcPath"),
    hcScheme: varchar("hcScheme"),
    hcMode: varchar("hcMode").default("http"),
    hcHostname: varchar("hcHostname"),
    hcPort: integer("hcPort"),
    hcInterval: integer("hcInterval").default(30), // in seconds
    hcUnhealthyInterval: integer("hcUnhealthyInterval").default(30), // in seconds
    hcTimeout: integer("hcTimeout").default(5), // in seconds
    hcHeaders: varchar("hcHeaders"),
    hcFollowRedirects: boolean("hcFollowRedirects").default(true),
    hcMethod: varchar("hcMethod").default("GET"),
    hcStatus: integer("hcStatus"), // http code
    hcHealth: text("hcHealth").default("unknown") // "unknown", "healthy", "unhealthy"
});

export const exitNodes = pgTable("exitNodes", {
    exitNodeId: serial("exitNodeId").primaryKey(),
    name: varchar("name").notNull(),
@@ -689,3 +710,4 @@ export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SiteResource = InferSelectModel<typeof siteResources>;
export type SetupToken = InferSelectModel<typeof setupTokens>;
export type HostMeta = InferSelectModel<typeof hostMeta>;
export type TargetHealthCheck = InferSelectModel<typeof targetHealthCheck>;
server/db/private/rateLimit.test.ts (new file, 202 lines)
@@ -0,0 +1,202 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

// Simple test file for the rate limit service with Redis
// Run with: npx ts-node rateLimit.test.ts

import { RateLimitService } from './rateLimit';

function generateClientId() {
    return 'client-' + Math.random().toString(36).substring(2, 15);
}

async function runTests() {
    console.log('Starting Rate Limit Service Tests...\n');

    const rateLimitService = new RateLimitService();
    let testsPassed = 0;
    let testsTotal = 0;

    // Helper function to run a test
    async function test(name: string, testFn: () => Promise<void>) {
        testsTotal++;
        try {
            await testFn();
            console.log(`✅ ${name}`);
            testsPassed++;
        } catch (error) {
            console.log(`❌ ${name}: ${error}`);
        }
    }

    // Helper function for assertions
    function assert(condition: boolean, message: string) {
        if (!condition) {
            throw new Error(message);
        }
    }

    // Test 1: Basic rate limiting
    await test('Should allow requests under the limit', async () => {
        const clientId = generateClientId();
        const maxRequests = 5;

        for (let i = 0; i < maxRequests - 1; i++) {
            const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
            assert(!result.isLimited, `Request ${i + 1} should be allowed`);
            assert(result.totalHits === i + 1, `Expected ${i + 1} hits, got ${result.totalHits}`);
        }
    });

    // Test 2: Rate limit blocking
    await test('Should block requests over the limit', async () => {
        const clientId = generateClientId();
        const maxRequests = 30;

        // Use up all allowed requests
        for (let i = 0; i < maxRequests - 1; i++) {
            const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
            assert(!result.isLimited, `Request ${i + 1} should be allowed`);
        }

        // Next request should be blocked
        const blockedResult = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        assert(blockedResult.isLimited, 'Request should be blocked');
        assert(blockedResult.reason === 'global', 'Should be blocked for global reason');
    });

    // Test 3: Message type limits
    await test('Should handle message type limits', async () => {
        const clientId = generateClientId();
        const globalMax = 10;
        const messageTypeMax = 2;

        // Send messages of type 'ping' up to the limit
        for (let i = 0; i < messageTypeMax - 1; i++) {
            const result = await rateLimitService.checkRateLimit(
                clientId,
                'ping',
                globalMax,
                messageTypeMax
            );
            assert(!result.isLimited, `Ping message ${i + 1} should be allowed`);
        }

        // Next 'ping' should be blocked
        const blockedResult = await rateLimitService.checkRateLimit(
            clientId,
            'ping',
            globalMax,
            messageTypeMax
        );
        assert(blockedResult.isLimited, 'Ping message should be blocked');
        assert(blockedResult.reason === 'message_type:ping', 'Should be blocked for message type');

        // Other message types should still work
        const otherResult = await rateLimitService.checkRateLimit(
            clientId,
            'pong',
            globalMax,
            messageTypeMax
        );
        assert(!otherResult.isLimited, 'Pong message should be allowed');
    });

    // Test 4: Reset functionality
    await test('Should reset client correctly', async () => {
        const clientId = generateClientId();
        const maxRequests = 3;

        // Use up some requests
        await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        await rateLimitService.checkRateLimit(clientId, 'test', maxRequests);

        // Reset the client
        await rateLimitService.resetKey(clientId);

        // Should be able to make fresh requests
        const result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        assert(!result.isLimited, 'Request after reset should be allowed');
        assert(result.totalHits === 1, 'Should have 1 hit after reset');
    });

    // Test 5: Different clients are independent
    await test('Should handle different clients independently', async () => {
        const client1 = generateClientId();
        const client2 = generateClientId();
        const maxRequests = 2;

        // Client 1 uses up their limit
        await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
        await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
        const client1Blocked = await rateLimitService.checkRateLimit(client1, undefined, maxRequests);
        assert(client1Blocked.isLimited, 'Client 1 should be blocked');

        // Client 2 should still be able to make requests
        const client2Result = await rateLimitService.checkRateLimit(client2, undefined, maxRequests);
        assert(!client2Result.isLimited, 'Client 2 should not be blocked');
        assert(client2Result.totalHits === 1, 'Client 2 should have 1 hit');
    });

    // Test 6: Decrement functionality
    await test('Should decrement correctly', async () => {
        const clientId = generateClientId();
        const maxRequests = 5;

        // Make some requests
        await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        let result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        assert(result.totalHits === 3, 'Should have 3 hits before decrement');

        // Decrement
        await rateLimitService.decrementRateLimit(clientId);

        // Next request should reflect the decrement
        result = await rateLimitService.checkRateLimit(clientId, undefined, maxRequests);
        assert(result.totalHits === 3, 'Should have 3 hits after decrement + increment');
    });

    // Wait a moment for any pending Redis operations
    console.log('\nWaiting for Redis sync...');
    await new Promise(resolve => setTimeout(resolve, 1000));

    // Force sync to test Redis integration
    await test('Should sync to Redis', async () => {
        await rateLimitService.forceSyncAllPendingData();
        // If this doesn't throw, Redis sync is working
        assert(true, 'Redis sync completed');
    });

    // Cleanup
    await rateLimitService.cleanup();

    // Results
    console.log(`\n--- Test Results ---`);
    console.log(`✅ Passed: ${testsPassed}/${testsTotal}`);
    console.log(`❌ Failed: ${testsTotal - testsPassed}/${testsTotal}`);

    if (testsPassed === testsTotal) {
        console.log('\n🎉 All tests passed!');
        process.exit(0);
    } else {
        console.log('\n💥 Some tests failed!');
        process.exit(1);
    }
}

// Run the tests
runTests().catch(error => {
    console.error('Test runner error:', error);
    process.exit(1);
});

server/db/private/rateLimit.ts (new file, 458 lines)
@@ -0,0 +1,458 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import logger from "@server/logger";
import redisManager from "@server/db/private/redis";
import { build } from "@server/build";

// Rate limiting configuration
export const RATE_LIMIT_WINDOW = 60; // 1 minute in seconds
export const RATE_LIMIT_MAX_REQUESTS = 100;
export const RATE_LIMIT_PER_MESSAGE_TYPE = 20; // Per message type limit within the window

// Configuration for batched Redis sync
export const REDIS_SYNC_THRESHOLD = 15; // Sync to Redis every N messages
export const REDIS_SYNC_FORCE_INTERVAL = 30000; // Force sync every 30 seconds as backup
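
// Worked example of the batching thresholds above: a client that sends 40
// messages within one window triggers a Redis write after message 15 and
// again after message 30 (pendingCount reaching REDIS_SYNC_THRESHOLD resets
// it to 0 on each sync); the final 10 messages stay in pendingCount until the
// 30-second force-sync interval flushes them.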

interface RateLimitTracker {
    count: number;
    windowStart: number;
    pendingCount: number;
    lastSyncedCount: number;
}

interface RateLimitResult {
    isLimited: boolean;
    reason?: string;
    totalHits?: number;
    resetTime?: Date;
}

export class RateLimitService {
    private localRateLimitTracker: Map<string, RateLimitTracker> = new Map();
    private localMessageTypeRateLimitTracker: Map<string, RateLimitTracker> = new Map();
    private cleanupInterval: NodeJS.Timeout | null = null;
    private forceSyncInterval: NodeJS.Timeout | null = null;

    constructor() {
        if (build == "oss") {
            return;
        }

        // Start cleanup and sync intervals
        this.cleanupInterval = setInterval(() => {
            this.cleanupLocalRateLimit().catch((error) => {
                logger.error("Error during rate limit cleanup:", error);
            });
        }, 60000); // Run cleanup every minute

        this.forceSyncInterval = setInterval(() => {
            this.forceSyncAllPendingData().catch((error) => {
                logger.error("Error during force sync:", error);
            });
        }, REDIS_SYNC_FORCE_INTERVAL);
    }

    // Redis keys
    private getRateLimitKey(clientId: string): string {
        return `ratelimit:${clientId}`;
    }

    private getMessageTypeRateLimitKey(clientId: string, messageType: string): string {
        return `ratelimit:${clientId}:${messageType}`;
    }

    // Helper function to sync local rate limit data to Redis
    private async syncRateLimitToRedis(
        clientId: string,
        tracker: RateLimitTracker
    ): Promise<void> {
        if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;

        try {
            const currentTime = Math.floor(Date.now() / 1000);
            const globalKey = this.getRateLimitKey(clientId);

            // Get current value and add pending count
            const currentValue = await redisManager.hget(
                globalKey,
                currentTime.toString()
            );
            const newValue = (
                parseInt(currentValue || "0") + tracker.pendingCount
            ).toString();
            await redisManager.hset(globalKey, currentTime.toString(), newValue);

            // Set TTL using the client directly
            if (redisManager.getClient()) {
                await redisManager
                    .getClient()
                    .expire(globalKey, RATE_LIMIT_WINDOW + 10);
            }

            // Update tracking
            tracker.lastSyncedCount = tracker.count;
            tracker.pendingCount = 0;

            logger.debug(`Synced global rate limit to Redis for client ${clientId}`);
        } catch (error) {
            logger.error("Failed to sync global rate limit to Redis:", error);
        }
    }

    private async syncMessageTypeRateLimitToRedis(
        clientId: string,
        messageType: string,
        tracker: RateLimitTracker
    ): Promise<void> {
        if (!redisManager.isRedisEnabled() || tracker.pendingCount === 0) return;

        try {
            const currentTime = Math.floor(Date.now() / 1000);
            const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);

            // Get current value and add pending count
            const currentValue = await redisManager.hget(
                messageTypeKey,
                currentTime.toString()
            );
            const newValue = (
                parseInt(currentValue || "0") + tracker.pendingCount
            ).toString();
            await redisManager.hset(
                messageTypeKey,
                currentTime.toString(),
                newValue
            );

            // Set TTL using the client directly
            if (redisManager.getClient()) {
                await redisManager
                    .getClient()
                    .expire(messageTypeKey, RATE_LIMIT_WINDOW + 10);
            }

            // Update tracking
            tracker.lastSyncedCount = tracker.count;
            tracker.pendingCount = 0;

            logger.debug(
                `Synced message type rate limit to Redis for client ${clientId}, type ${messageType}`
            );
        } catch (error) {
            logger.error("Failed to sync message type rate limit to Redis:", error);
        }
    }

    // Initialize local tracker from Redis data
    private async initializeLocalTracker(clientId: string): Promise<RateLimitTracker> {
        const currentTime = Math.floor(Date.now() / 1000);
        const windowStart = currentTime - RATE_LIMIT_WINDOW;

        if (!redisManager.isRedisEnabled()) {
            return {
                count: 0,
                windowStart: currentTime,
                pendingCount: 0,
                lastSyncedCount: 0
            };
        }

        try {
            const globalKey = this.getRateLimitKey(clientId);
            const globalRateLimitData = await redisManager.hgetall(globalKey);

            let count = 0;
            for (const [timestamp, countStr] of Object.entries(globalRateLimitData)) {
                const time = parseInt(timestamp);
                if (time >= windowStart) {
                    count += parseInt(countStr);
                }
            }

            return {
                count,
                windowStart: currentTime,
                pendingCount: 0,
                lastSyncedCount: count
            };
        } catch (error) {
            logger.error("Failed to initialize global tracker from Redis:", error);
            return {
                count: 0,
                windowStart: currentTime,
                pendingCount: 0,
                lastSyncedCount: 0
            };
        }
    }

    private async initializeMessageTypeTracker(
        clientId: string,
        messageType: string
    ): Promise<RateLimitTracker> {
        const currentTime = Math.floor(Date.now() / 1000);
        const windowStart = currentTime - RATE_LIMIT_WINDOW;

        if (!redisManager.isRedisEnabled()) {
            return {
                count: 0,
                windowStart: currentTime,
                pendingCount: 0,
                lastSyncedCount: 0
            };
        }

        try {
            const messageTypeKey = this.getMessageTypeRateLimitKey(clientId, messageType);
            const messageTypeRateLimitData = await redisManager.hgetall(messageTypeKey);

            let count = 0;
            for (const [timestamp, countStr] of Object.entries(messageTypeRateLimitData)) {
                const time = parseInt(timestamp);
                if (time >= windowStart) {
                    count += parseInt(countStr);
                }
            }

            return {
                count,
                windowStart: currentTime,
                pendingCount: 0,
                lastSyncedCount: count
            };
        } catch (error) {
            logger.error("Failed to initialize message type tracker from Redis:", error);
            return {
                count: 0,
                windowStart: currentTime,
                pendingCount: 0,
                lastSyncedCount: 0
            };
        }
    }

    // Main rate limiting function
    async checkRateLimit(
        clientId: string,
        messageType?: string,
        maxRequests: number = RATE_LIMIT_MAX_REQUESTS,
        messageTypeLimit: number = RATE_LIMIT_PER_MESSAGE_TYPE,
        windowMs: number = RATE_LIMIT_WINDOW * 1000
    ): Promise<RateLimitResult> {
        const currentTime = Math.floor(Date.now() / 1000);
        const windowStart = currentTime - Math.floor(windowMs / 1000);

        // Check global rate limit
        let globalTracker = this.localRateLimitTracker.get(clientId);

        if (!globalTracker || globalTracker.windowStart < windowStart) {
            // New window or first request - initialize from Redis if available
            globalTracker = await this.initializeLocalTracker(clientId);
            globalTracker.windowStart = currentTime;
            this.localRateLimitTracker.set(clientId, globalTracker);
        }

        // Increment global counters
        globalTracker.count++;
        globalTracker.pendingCount++;
        this.localRateLimitTracker.set(clientId, globalTracker);

        // Check if global limit would be exceeded
        if (globalTracker.count >= maxRequests) {
            return {
                isLimited: true,
                reason: "global",
                totalHits: globalTracker.count,
                resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
            };
        }

        // Sync to Redis if threshold reached
        if (globalTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
            this.syncRateLimitToRedis(clientId, globalTracker);
        }

        // Check message type specific rate limit if messageType is provided
        if (messageType) {
            const messageTypeKey = `${clientId}:${messageType}`;
            let messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);

            if (!messageTypeTracker || messageTypeTracker.windowStart < windowStart) {
                // New window or first request for this message type - initialize from Redis if available
                messageTypeTracker = await this.initializeMessageTypeTracker(clientId, messageType);
                messageTypeTracker.windowStart = currentTime;
                this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);
            }

            // Increment message type counters
            messageTypeTracker.count++;
            messageTypeTracker.pendingCount++;
            this.localMessageTypeRateLimitTracker.set(messageTypeKey, messageTypeTracker);

            // Check if message type limit would be exceeded
            if (messageTypeTracker.count >= messageTypeLimit) {
                return {
                    isLimited: true,
                    reason: `message_type:${messageType}`,
                    totalHits: messageTypeTracker.count,
                    resetTime: new Date((messageTypeTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
                };
            }

            // Sync to Redis if threshold reached
            if (messageTypeTracker.pendingCount >= REDIS_SYNC_THRESHOLD) {
                this.syncMessageTypeRateLimitToRedis(clientId, messageType, messageTypeTracker);
            }
        }

        return {
            isLimited: false,
            totalHits: globalTracker.count,
            resetTime: new Date((globalTracker.windowStart + Math.floor(windowMs / 1000)) * 1000)
        };
    }

    // Decrement function for skipSuccessfulRequests/skipFailedRequests functionality
    async decrementRateLimit(clientId: string, messageType?: string): Promise<void> {
        // Decrement global counter
        const globalTracker = this.localRateLimitTracker.get(clientId);
        if (globalTracker && globalTracker.count > 0) {
            globalTracker.count--;
            // We need to account for this in pending count to sync correctly
            globalTracker.pendingCount--;
        }

        // Decrement message type counter if provided
        if (messageType) {
            const messageTypeKey = `${clientId}:${messageType}`;
            const messageTypeTracker = this.localMessageTypeRateLimitTracker.get(messageTypeKey);
            if (messageTypeTracker && messageTypeTracker.count > 0) {
                messageTypeTracker.count--;
                messageTypeTracker.pendingCount--;
            }
        }
    }

    // Reset key function
    async resetKey(clientId: string): Promise<void> {
        // Remove from local tracking
        this.localRateLimitTracker.delete(clientId);

        // Remove all message type entries for this client
        for (const [key] of this.localMessageTypeRateLimitTracker) {
            if (key.startsWith(`${clientId}:`)) {
                this.localMessageTypeRateLimitTracker.delete(key);
            }
        }

        // Remove from Redis if enabled
        if (redisManager.isRedisEnabled()) {
            const globalKey = this.getRateLimitKey(clientId);
            await redisManager.del(globalKey);

            // Get all message type keys for this client and delete them
            const client = redisManager.getClient();
            if (client) {
                const messageTypeKeys = await client.keys(`ratelimit:${clientId}:*`);
                if (messageTypeKeys.length > 0) {
                    await Promise.all(messageTypeKeys.map(key => redisManager.del(key)));
                }
            }
        }
    }

    // Cleanup old local rate limit entries and force sync pending data
    private async cleanupLocalRateLimit(): Promise<void> {
        const currentTime = Math.floor(Date.now() / 1000);
        const windowStart = currentTime - RATE_LIMIT_WINDOW;

        // Clean up global rate limit tracking and sync pending data
        for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
            if (tracker.windowStart < windowStart) {
                // Sync any pending data before cleanup
                if (tracker.pendingCount > 0) {
                    await this.syncRateLimitToRedis(clientId, tracker);
                }
                this.localRateLimitTracker.delete(clientId);
            }
        }

        // Clean up message type rate limit tracking and sync pending data
        for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
            if (tracker.windowStart < windowStart) {
                // Sync any pending data before cleanup
                if (tracker.pendingCount > 0) {
                    const [clientId, messageType] = key.split(":", 2);
                    await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
                }
                this.localMessageTypeRateLimitTracker.delete(key);
            }
        }
    }

    // Force sync all pending rate limit data to Redis
    async forceSyncAllPendingData(): Promise<void> {
        if (!redisManager.isRedisEnabled()) return;

        logger.debug("Force syncing all pending rate limit data to Redis...");

        // Sync all pending global rate limits
        for (const [clientId, tracker] of this.localRateLimitTracker.entries()) {
            if (tracker.pendingCount > 0) {
                await this.syncRateLimitToRedis(clientId, tracker);
            }
        }

        // Sync all pending message type rate limits
        for (const [key, tracker] of this.localMessageTypeRateLimitTracker.entries()) {
            if (tracker.pendingCount > 0) {
                const [clientId, messageType] = key.split(":", 2);
                await this.syncMessageTypeRateLimitToRedis(clientId, messageType, tracker);
            }
        }

        logger.debug("Completed force sync of pending rate limit data");
    }

    // Cleanup function for graceful shutdown
    async cleanup(): Promise<void> {
        if (build == "oss") {
            return;
        }

        // Clear intervals
        if (this.cleanupInterval) {
            clearInterval(this.cleanupInterval);
        }
        if (this.forceSyncInterval) {
            clearInterval(this.forceSyncInterval);
        }

        // Force sync all pending data
        await this.forceSyncAllPendingData();

        // Clear local data
        this.localRateLimitTracker.clear();
        this.localMessageTypeRateLimitTracker.clear();

        logger.info("Rate limit service cleanup completed");
    }
}

// Export singleton instance
export const rateLimitService = new RateLimitService();

// Handle process termination
process.on("SIGTERM", () => rateLimitService.cleanup());
process.on("SIGINT", () => rateLimitService.cleanup());

server/db/private/redis.ts (new file, 782 lines)
@@ -0,0 +1,782 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import Redis, { RedisOptions } from "ioredis";
import logger from "@server/logger";
import config from "@server/lib/config";
import { build } from "@server/build";

class RedisManager {
    public client: Redis | null = null;
    private writeClient: Redis | null = null; // Master for writes
    private readClient: Redis | null = null; // Replica for reads
    private subscriber: Redis | null = null;
    private publisher: Redis | null = null;
    private isEnabled: boolean = false;
    private isHealthy: boolean = true;
    private isWriteHealthy: boolean = true;
    private isReadHealthy: boolean = true;
    private lastHealthCheck: number = 0;
    private healthCheckInterval: number = 30000; // 30 seconds
    private connectionTimeout: number = 15000; // 15 seconds
    private commandTimeout: number = 15000; // 15 seconds
    private hasReplicas: boolean = false;
    private maxRetries: number = 3;
    private baseRetryDelay: number = 100; // 100ms
    private maxRetryDelay: number = 2000; // 2 seconds
    private backoffMultiplier: number = 2;
    private subscribers: Map<
        string,
        Set<(channel: string, message: string) => void>
    > = new Map();
    private reconnectionCallbacks: Set<() => Promise<void>> = new Set();

    constructor() {
        if (build == "oss") {
            this.isEnabled = false;
            return;
        }
        this.isEnabled = config.getRawPrivateConfig().flags?.enable_redis || false;
        if (this.isEnabled) {
            this.initializeClients();
        }
    }

    // Register callback to be called when Redis reconnects
    public onReconnection(callback: () => Promise<void>): void {
        this.reconnectionCallbacks.add(callback);
    }

    // Unregister reconnection callback
    public offReconnection(callback: () => Promise<void>): void {
        this.reconnectionCallbacks.delete(callback);
    }
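
    // Usage sketch (hypothetical consumer): register an async resync callback
    // that runs whenever a client reports "ready" after a reconnect:
    //
    //     const resyncCaches = async () => { /* rebuild locally cached state */ };
    //     redisManager.onReconnection(resyncCaches);
    //     redisManager.offReconnection(resyncCaches); // on teardown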
|
||||
|
||||
private async triggerReconnectionCallbacks(): Promise<void> {
|
||||
logger.info(`Triggering ${this.reconnectionCallbacks.size} reconnection callbacks`);
|
||||
|
||||
const promises = Array.from(this.reconnectionCallbacks).map(async (callback) => {
|
||||
try {
|
||||
await callback();
|
||||
} catch (error) {
|
||||
logger.error("Error in reconnection callback:", error);
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.allSettled(promises);
|
||||
}

    private async resubscribeToChannels(): Promise<void> {
        if (!this.subscriber || this.subscribers.size === 0) return;

        logger.info(`Re-subscribing to ${this.subscribers.size} channels after Redis reconnection`);

        try {
            const channels = Array.from(this.subscribers.keys());
            if (channels.length > 0) {
                await this.subscriber.subscribe(...channels);
                logger.info(`Successfully re-subscribed to channels: ${channels.join(', ')}`);
            }
        } catch (error) {
            logger.error("Failed to re-subscribe to channels:", error);
        }
    }

    private getRedisConfig(): RedisOptions {
        const redisConfig = config.getRawPrivateConfig().redis!;
        const opts: RedisOptions = {
            host: redisConfig.host!,
            port: redisConfig.port!,
            password: redisConfig.password,
            db: redisConfig.db,
            // tls: {
            //     rejectUnauthorized:
            //         redisConfig.tls?.reject_unauthorized || false
            // }
        };
        return opts;
    }

    private getReplicaRedisConfig(): RedisOptions | null {
        const redisConfig = config.getRawPrivateConfig().redis!;
        if (!redisConfig.replicas || redisConfig.replicas.length === 0) {
            return null;
        }

        // Use the first replica for simplicity
        // In production, you might want to implement load balancing across replicas
        const replica = redisConfig.replicas[0];
        const opts: RedisOptions = {
            host: replica.host!,
            port: replica.port!,
            password: replica.password,
            db: replica.db || redisConfig.db,
            // tls: {
            //     rejectUnauthorized:
            //         replica.tls?.reject_unauthorized || false
            // }
        };
        return opts;
    }

    // Reconnection logic is wired up in initializeClients
    private initializeClients(): void {
        const masterConfig = this.getRedisConfig();
        const replicaConfig = this.getReplicaRedisConfig();

        this.hasReplicas = replicaConfig !== null;

        try {
            // Initialize master connection for writes
            this.writeClient = new Redis({
                ...masterConfig,
                enableReadyCheck: false,
                maxRetriesPerRequest: 3,
                keepAlive: 30000,
                connectTimeout: this.connectionTimeout,
                commandTimeout: this.commandTimeout,
            });

            // Initialize replica connection for reads (if available)
            if (this.hasReplicas) {
                this.readClient = new Redis({
                    ...replicaConfig!,
                    enableReadyCheck: false,
                    maxRetriesPerRequest: 3,
                    keepAlive: 30000,
                    connectTimeout: this.connectionTimeout,
                    commandTimeout: this.commandTimeout,
                });
            } else {
                // Fallback to master for reads if no replicas
                this.readClient = this.writeClient;
            }

            // Backward compatibility - point to write client
            this.client = this.writeClient;

            // Publisher uses master (writes)
            this.publisher = new Redis({
                ...masterConfig,
                enableReadyCheck: false,
                maxRetriesPerRequest: 3,
                keepAlive: 30000,
                connectTimeout: this.connectionTimeout,
                commandTimeout: this.commandTimeout,
            });

            // Subscriber uses replica if available (reads)
            this.subscriber = new Redis({
                ...(this.hasReplicas ? replicaConfig! : masterConfig),
                enableReadyCheck: false,
                maxRetriesPerRequest: 3,
                keepAlive: 30000,
                connectTimeout: this.connectionTimeout,
                commandTimeout: this.commandTimeout,
            });

            // Add reconnection handlers for write client
            this.writeClient.on("error", (err) => {
                logger.error("Redis write client error:", err);
                this.isWriteHealthy = false;
                this.isHealthy = false;
            });

            this.writeClient.on("reconnecting", () => {
                logger.info("Redis write client reconnecting...");
                this.isWriteHealthy = false;
                this.isHealthy = false;
            });

            this.writeClient.on("ready", () => {
                logger.info("Redis write client ready");
                this.isWriteHealthy = true;
                this.updateOverallHealth();

                // Trigger reconnection callbacks when Redis comes back online
                if (this.isHealthy) {
                    this.triggerReconnectionCallbacks().catch(error => {
                        logger.error("Error triggering reconnection callbacks:", error);
                    });
                }
            });

            this.writeClient.on("connect", () => {
                logger.info("Redis write client connected");
            });

            // Add reconnection handlers for read client (if different from write)
            if (this.hasReplicas && this.readClient !== this.writeClient) {
                this.readClient.on("error", (err) => {
                    logger.error("Redis read client error:", err);
                    this.isReadHealthy = false;
                    this.updateOverallHealth();
                });

                this.readClient.on("reconnecting", () => {
                    logger.info("Redis read client reconnecting...");
                    this.isReadHealthy = false;
                    this.updateOverallHealth();
                });

                this.readClient.on("ready", () => {
                    logger.info("Redis read client ready");
                    this.isReadHealthy = true;
                    this.updateOverallHealth();

                    // Trigger reconnection callbacks when Redis comes back online
                    if (this.isHealthy) {
                        this.triggerReconnectionCallbacks().catch(error => {
                            logger.error("Error triggering reconnection callbacks:", error);
                        });
                    }
                });

                this.readClient.on("connect", () => {
                    logger.info("Redis read client connected");
                });
            } else {
                // If using same client for reads and writes
                this.isReadHealthy = this.isWriteHealthy;
            }

            this.publisher.on("error", (err) => {
                logger.error("Redis publisher error:", err);
            });

            this.publisher.on("ready", () => {
                logger.info("Redis publisher ready");
            });

            this.publisher.on("connect", () => {
                logger.info("Redis publisher connected");
            });

            this.subscriber.on("error", (err) => {
                logger.error("Redis subscriber error:", err);
            });

            this.subscriber.on("ready", () => {
                logger.info("Redis subscriber ready");
                // Re-subscribe to all channels after reconnection
                this.resubscribeToChannels().catch((error: any) => {
                    logger.error("Error re-subscribing to channels:", error);
                });
            });

            this.subscriber.on("connect", () => {
                logger.info("Redis subscriber connected");
            });

            // Set up message handler for subscriber
            this.subscriber.on(
                "message",
                (channel: string, message: string) => {
                    const channelSubscribers = this.subscribers.get(channel);
                    if (channelSubscribers) {
                        channelSubscribers.forEach((callback) => {
                            try {
                                callback(channel, message);
                            } catch (error) {
                                logger.error(
                                    `Error in subscriber callback for channel ${channel}:`,
                                    error
                                );
                            }
                        });
                    }
                }
            );

            const setupMessage = this.hasReplicas
                ? "Redis clients initialized successfully with replica support"
                : "Redis clients initialized successfully (single instance)";
            logger.info(setupMessage);

            // Start periodic health monitoring
            this.startHealthMonitoring();
        } catch (error) {
            logger.error("Failed to initialize Redis clients:", error);
            this.isEnabled = false;
        }
    }

    private updateOverallHealth(): void {
        // Overall health is true if write is healthy and (read is healthy OR we don't have replicas)
        this.isHealthy = this.isWriteHealthy && (this.isReadHealthy || !this.hasReplicas);
    }

    private async executeWithRetry<T>(
        operation: () => Promise<T>,
        operationName: string,
        fallbackOperation?: () => Promise<T>
    ): Promise<T> {
        let lastError: Error | null = null;

        for (let attempt = 0; attempt <= this.maxRetries; attempt++) {
            try {
                return await operation();
            } catch (error) {
                lastError = error as Error;

                // If this is the last attempt, try fallback if available
                if (attempt === this.maxRetries && fallbackOperation) {
                    try {
                        logger.warn(`${operationName} primary operation failed, trying fallback`);
                        return await fallbackOperation();
                    } catch (fallbackError) {
                        logger.error(`${operationName} fallback also failed:`, fallbackError);
                        throw lastError;
                    }
                }

                // Don't retry on the last attempt
                if (attempt === this.maxRetries) {
                    break;
                }

                // Calculate delay with exponential backoff
                const delay = Math.min(
                    this.baseRetryDelay * Math.pow(this.backoffMultiplier, attempt),
                    this.maxRetryDelay
                );

                logger.warn(`${operationName} failed (attempt ${attempt + 1}/${this.maxRetries + 1}), retrying in ${delay}ms:`, error);

                // Wait before retrying
                await new Promise(resolve => setTimeout(resolve, delay));
            }
        }

        logger.error(`${operationName} failed after ${this.maxRetries + 1} attempts:`, lastError);
        throw lastError;
    }
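
    /*
     * Worked example of the backoff schedule above. The field values here are
     * hypothetical; the real defaults (baseRetryDelay, backoffMultiplier,
     * maxRetryDelay, maxRetries) are set elsewhere in this class.
     *
     *     const baseRetryDelay = 250;   // ms (assumed)
     *     const backoffMultiplier = 2;  // (assumed)
     *     const maxRetryDelay = 5000;   // ms (assumed)
     *     for (let attempt = 0; attempt < 4; attempt++) {
     *         const delay = Math.min(
     *             baseRetryDelay * Math.pow(backoffMultiplier, attempt),
     *             maxRetryDelay
     *         );
     *         // attempt 0 -> 250ms, 1 -> 500ms, 2 -> 1000ms, 3 -> 2000ms
     *     }
     */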

    private startHealthMonitoring(): void {
        if (!this.isEnabled) return;

        // Check health every 30 seconds
        setInterval(async () => {
            try {
                await this.checkRedisHealth();
            } catch (error) {
                logger.error("Error during Redis health monitoring:", error);
            }
        }, this.healthCheckInterval);
    }

    public isRedisEnabled(): boolean {
        return this.isEnabled && this.client !== null && this.isHealthy;
    }

    private async checkRedisHealth(): Promise<boolean> {
        const now = Date.now();

        // Only check health every 30 seconds
        if (now - this.lastHealthCheck < this.healthCheckInterval) {
            return this.isHealthy;
        }

        this.lastHealthCheck = now;

        if (!this.writeClient) {
            this.isHealthy = false;
            this.isWriteHealthy = false;
            this.isReadHealthy = false;
            return false;
        }

        try {
            // Check write client (master) health
            await Promise.race([
                this.writeClient.ping(),
                new Promise((_, reject) =>
                    setTimeout(() => reject(new Error('Write client health check timeout')), 2000)
                )
            ]);
            this.isWriteHealthy = true;

            // Check read client health if it's different from write client
            if (this.hasReplicas && this.readClient && this.readClient !== this.writeClient) {
                try {
                    await Promise.race([
                        this.readClient.ping(),
                        new Promise((_, reject) =>
                            setTimeout(() => reject(new Error('Read client health check timeout')), 2000)
                        )
                    ]);
                    this.isReadHealthy = true;
                } catch (error) {
                    logger.error("Redis read client health check failed:", error);
                    this.isReadHealthy = false;
                }
            } else {
                this.isReadHealthy = this.isWriteHealthy;
            }

            this.updateOverallHealth();
            return this.isHealthy;
        } catch (error) {
            logger.error("Redis write client health check failed:", error);
            this.isWriteHealthy = false;
            this.isReadHealthy = false; // If write fails, consider read as failed too for safety
            this.isHealthy = false;
            return false;
        }
    }
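
    /*
     * The Promise.race timeout pattern above also appears in publish() and
     * subscribe() below. A small helper could factor it out; sketch only, not
     * part of this change. Note that the losing setTimeout is never cleared,
     * so a timer stays pending for up to `ms` even after success.
     *
     *     function withTimeout<T>(promise: Promise<T>, ms: number, label: string): Promise<T> {
     *         return Promise.race([
     *             promise,
     *             new Promise<T>((_, reject) =>
     *                 setTimeout(() => reject(new Error(`${label} timeout`)), ms)
     *             )
     *         ]);
     *     }
     *
     *     // await withTimeout(this.writeClient.ping(), 2000, "Write client health check");
     */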

    public getClient(): Redis {
        return this.client!;
    }

    public getWriteClient(): Redis | null {
        return this.writeClient;
    }

    public getReadClient(): Redis | null {
        return this.readClient;
    }

    public hasReplicaSupport(): boolean {
        return this.hasReplicas;
    }

    public getHealthStatus(): {
        isEnabled: boolean;
        isHealthy: boolean;
        isWriteHealthy: boolean;
        isReadHealthy: boolean;
        hasReplicas: boolean;
    } {
        return {
            isEnabled: this.isEnabled,
            isHealthy: this.isHealthy,
            isWriteHealthy: this.isWriteHealthy,
            isReadHealthy: this.isReadHealthy,
            hasReplicas: this.hasReplicas
        };
    }
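
    /*
     * getHealthStatus() is convenient for an operational endpoint. A hedged
     * sketch; the route path and wiring are illustrative, not from this diff:
     *
     *     app.get("/internal/redis-health", (req, res) => {
     *         const status = redisManager.getHealthStatus();
     *         // 503 lets a load balancer treat degraded Redis as unhealthy
     *         res.status(status.isHealthy ? 200 : 503).json(status);
     *     });
     */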

    public async set(
        key: string,
        value: string,
        ttl?: number
    ): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

        try {
            await this.executeWithRetry(
                async () => {
                    if (ttl) {
                        await this.writeClient!.setex(key, ttl, value);
                    } else {
                        await this.writeClient!.set(key, value);
                    }
                },
                "Redis SET"
            );
            return true;
        } catch (error) {
            logger.error("Redis SET error:", error);
            return false;
        }
    }

    public async get(key: string): Promise<string | null> {
        if (!this.isRedisEnabled() || !this.readClient) return null;

        try {
            const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
                ? () => this.writeClient!.get(key)
                : undefined;

            return await this.executeWithRetry(
                () => this.readClient!.get(key),
                "Redis GET",
                fallbackOperation
            );
        } catch (error) {
            logger.error("Redis GET error:", error);
            return null;
        }
    }
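
    /*
     * Read path note: reads prefer the replica and fall back to the master
     * only when replicas exist and the master is healthy; the same pattern
     * repeats in smembers, hget, and hgetall below. Typical flow:
     *
     *     await redisManager.set("session:abc", "user-123", 3600); // master
     *     const owner = await redisManager.get("session:abc");     // replica
     *
     * Because replication is asynchronous, a freshly written key may not yet
     * be visible on the replica read.
     */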

    public async del(key: string): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

        try {
            await this.executeWithRetry(
                () => this.writeClient!.del(key),
                "Redis DEL"
            );
            return true;
        } catch (error) {
            logger.error("Redis DEL error:", error);
            return false;
        }
    }

    public async sadd(key: string, member: string): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

        try {
            await this.executeWithRetry(
                () => this.writeClient!.sadd(key, member),
                "Redis SADD"
            );
            return true;
        } catch (error) {
            logger.error("Redis SADD error:", error);
            return false;
        }
    }

    public async srem(key: string, member: string): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

        try {
            await this.executeWithRetry(
                () => this.writeClient!.srem(key, member),
                "Redis SREM"
            );
            return true;
        } catch (error) {
            logger.error("Redis SREM error:", error);
            return false;
        }
    }

    public async smembers(key: string): Promise<string[]> {
        if (!this.isRedisEnabled() || !this.readClient) return [];

        try {
            const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
                ? () => this.writeClient!.smembers(key)
                : undefined;

            return await this.executeWithRetry(
                () => this.readClient!.smembers(key),
                "Redis SMEMBERS",
                fallbackOperation
            );
        } catch (error) {
            logger.error("Redis SMEMBERS error:", error);
            return [];
        }
    }

    public async hset(
        key: string,
        field: string,
        value: string
    ): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

        try {
            await this.executeWithRetry(
                () => this.writeClient!.hset(key, field, value),
                "Redis HSET"
            );
            return true;
        } catch (error) {
            logger.error("Redis HSET error:", error);
            return false;
        }
    }

    public async hget(key: string, field: string): Promise<string | null> {
        if (!this.isRedisEnabled() || !this.readClient) return null;

        try {
            const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
                ? () => this.writeClient!.hget(key, field)
                : undefined;

            return await this.executeWithRetry(
                () => this.readClient!.hget(key, field),
                "Redis HGET",
                fallbackOperation
            );
        } catch (error) {
            logger.error("Redis HGET error:", error);
            return null;
        }
    }

    public async hdel(key: string, field: string): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.writeClient) return false;

        try {
            await this.executeWithRetry(
                () => this.writeClient!.hdel(key, field),
                "Redis HDEL"
            );
            return true;
        } catch (error) {
            logger.error("Redis HDEL error:", error);
            return false;
        }
    }

    public async hgetall(key: string): Promise<Record<string, string>> {
        if (!this.isRedisEnabled() || !this.readClient) return {};

        try {
            const fallbackOperation = (this.hasReplicas && this.writeClient && this.isWriteHealthy)
                ? () => this.writeClient!.hgetall(key)
                : undefined;

            return await this.executeWithRetry(
                () => this.readClient!.hgetall(key),
                "Redis HGETALL",
                fallbackOperation
            );
        } catch (error) {
            logger.error("Redis HGETALL error:", error);
            return {};
        }
    }

    public async publish(channel: string, message: string): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.publisher) return false;

        // Quick health check before attempting to publish
        const isHealthy = await this.checkRedisHealth();
        if (!isHealthy) {
            logger.warn("Skipping Redis publish due to unhealthy connection");
            return false;
        }

        try {
            await this.executeWithRetry(
                async () => {
                    // Add timeout to prevent hanging
                    return Promise.race([
                        this.publisher!.publish(channel, message),
                        new Promise((_, reject) =>
                            setTimeout(() => reject(new Error('Redis publish timeout')), 3000)
                        )
                    ]);
                },
                "Redis PUBLISH"
            );
            return true;
        } catch (error) {
            logger.error("Redis PUBLISH error:", error);
            this.isHealthy = false; // Mark as unhealthy on error
            return false;
        }
    }

    public async subscribe(
        channel: string,
        callback: (channel: string, message: string) => void
    ): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.subscriber) return false;

        try {
            // Add callback to subscribers map
            if (!this.subscribers.has(channel)) {
                this.subscribers.set(channel, new Set());
                // Only subscribe to the channel if it's the first subscriber
                await this.executeWithRetry(
                    async () => {
                        return Promise.race([
                            this.subscriber!.subscribe(channel),
                            new Promise((_, reject) =>
                                setTimeout(() => reject(new Error('Redis subscribe timeout')), 5000)
                            )
                        ]);
                    },
                    "Redis SUBSCRIBE"
                );
            }

            this.subscribers.get(channel)!.add(callback);
            return true;
        } catch (error) {
            logger.error("Redis SUBSCRIBE error:", error);
            this.isHealthy = false;
            return false;
        }
    }

    public async unsubscribe(
        channel: string,
        callback?: (channel: string, message: string) => void
    ): Promise<boolean> {
        if (!this.isRedisEnabled() || !this.subscriber) return false;

        try {
            const channelSubscribers = this.subscribers.get(channel);
            if (!channelSubscribers) return true;

            if (callback) {
                // Remove specific callback
                channelSubscribers.delete(callback);
                if (channelSubscribers.size === 0) {
                    this.subscribers.delete(channel);
                    await this.executeWithRetry(
                        () => this.subscriber!.unsubscribe(channel),
                        "Redis UNSUBSCRIBE"
                    );
                }
            } else {
                // Remove all callbacks for this channel
                this.subscribers.delete(channel);
                await this.executeWithRetry(
                    () => this.subscriber!.unsubscribe(channel),
                    "Redis UNSUBSCRIBE"
                );
            }

            return true;
        } catch (error) {
            logger.error("Redis UNSUBSCRIBE error:", error);
            return false;
        }
    }
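
    /*
     * Pub/sub round trip (illustrative; the channel name is made up):
     *
     *     const onConfigChange = (channel: string, message: string) => {
     *         logger.info(`Received on ${channel}: ${message}`);
     *     };
     *     await redisManager.subscribe("config:updated", onConfigChange);
     *     await redisManager.publish("config:updated", JSON.stringify({ version: 2 }));
     *     // Removes just this callback; the channel itself is unsubscribed
     *     // once no callbacks remain:
     *     await redisManager.unsubscribe("config:updated", onConfigChange);
     */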

    public async disconnect(): Promise<void> {
        try {
            // client and readClient may alias writeClient (see initializeClients),
            // so compare before quitting to avoid closing the same connection twice
            const readIsAlias = this.readClient === this.writeClient;
            if (this.client && this.client !== this.writeClient) {
                await this.client.quit();
            }
            this.client = null;
            if (this.writeClient) {
                await this.writeClient.quit();
                this.writeClient = null;
            }
            if (this.readClient && !readIsAlias) {
                await this.readClient.quit();
            }
            this.readClient = null;
            if (this.publisher) {
                await this.publisher.quit();
                this.publisher = null;
            }
            if (this.subscriber) {
                await this.subscriber.quit();
                this.subscriber = null;
            }
            this.subscribers.clear();
            logger.info("Redis clients disconnected");
        } catch (error) {
            logger.error("Error disconnecting Redis clients:", error);
        }
    }
}

export const redisManager = new RedisManager();
export const redis = redisManager.getClient();
export default redisManager;
223
server/db/private/redisStore.ts
Normal file
@@ -0,0 +1,223 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Store, Options, IncrementResponse } from 'express-rate-limit';
import { rateLimitService } from './rateLimit';
import logger from '@server/logger';

/**
 * A Redis-backed rate limiting store for express-rate-limit that optimizes
 * for local read performance and batched writes to Redis.
 *
 * This store uses the same optimized rate limiting logic as the WebSocket
 * implementation, providing:
 * - Local caching for fast reads
 * - Batched writes to Redis to reduce load
 * - Automatic cleanup of expired entries
 * - Graceful fallback when Redis is unavailable
 */
export default class RedisStore implements Store {
    /**
     * The duration of time before which all hit counts are reset (in milliseconds).
     */
    windowMs!: number;

    /**
     * Maximum number of requests allowed within the window.
     */
    max!: number;

    /**
     * Optional prefix for Redis keys to avoid collisions.
     */
    prefix: string;

    /**
     * Whether to skip incrementing on failed requests.
     */
    skipFailedRequests: boolean;

    /**
     * Whether to skip incrementing on successful requests.
     */
    skipSuccessfulRequests: boolean;

    /**
     * @constructor for RedisStore.
     *
     * @param options - Configuration options for the store.
     */
    constructor(options: {
        prefix?: string;
        skipFailedRequests?: boolean;
        skipSuccessfulRequests?: boolean;
    } = {}) {
        this.prefix = options.prefix || 'express-rate-limit';
        this.skipFailedRequests = options.skipFailedRequests || false;
        this.skipSuccessfulRequests = options.skipSuccessfulRequests || false;
    }

    /**
     * Method that actually initializes the store. Must be synchronous.
     *
     * @param options - The options used to setup express-rate-limit.
     */
    init(options: Options): void {
        this.windowMs = options.windowMs;
        this.max = options.max as number;

        // logger.debug(`RedisStore initialized with windowMs: ${this.windowMs}, max: ${this.max}, prefix: ${this.prefix}`);
    }

    /**
     * Method to increment a client's hit counter.
     *
     * @param key - The identifier for a client (usually IP address).
     * @returns Promise resolving to the number of hits and reset time for that client.
     */
    async increment(key: string): Promise<IncrementResponse> {
        try {
            const clientId = `${this.prefix}:${key}`;

            const result = await rateLimitService.checkRateLimit(
                clientId,
                undefined, // No message type for HTTP requests
                this.max,
                undefined, // No message type limit
                this.windowMs
            );

            // logger.debug(`Incremented rate limit for key: ${key} with max: ${this.max}, totalHits: ${result.totalHits}`);

            return {
                totalHits: result.totalHits || 1,
                resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
            };
        } catch (error) {
            logger.error(`RedisStore increment error for key ${key}:`, error);

            // Return safe defaults on error to prevent blocking requests
            return {
                totalHits: 1,
                resetTime: new Date(Date.now() + this.windowMs)
            };
        }
    }

    /**
     * Method to decrement a client's hit counter.
     * Used when skipSuccessfulRequests or skipFailedRequests is enabled.
     *
     * @param key - The identifier for a client.
     */
    async decrement(key: string): Promise<void> {
        try {
            const clientId = `${this.prefix}:${key}`;
            await rateLimitService.decrementRateLimit(clientId);

            // logger.debug(`Decremented rate limit for key: ${key}`);
        } catch (error) {
            logger.error(`RedisStore decrement error for key ${key}:`, error);
            // Don't throw - decrement failures shouldn't block requests
        }
    }

    /**
     * Method to reset a client's hit counter.
     *
     * @param key - The identifier for a client.
     */
    async resetKey(key: string): Promise<void> {
        try {
            const clientId = `${this.prefix}:${key}`;
            await rateLimitService.resetKey(clientId);

            // logger.debug(`Reset rate limit for key: ${key}`);
        } catch (error) {
            logger.error(`RedisStore resetKey error for key ${key}:`, error);
            // Don't throw - reset failures shouldn't block requests
        }
    }

    /**
     * Method to reset everyone's hit counter.
     *
     * This method is optional and is never called by express-rate-limit.
     * We implement it for completeness, but it's not recommended for production use
     * as it could be expensive with large datasets.
     */
    async resetAll(): Promise<void> {
        try {
            logger.warn('RedisStore resetAll called - this operation can be expensive');

            // Force sync all pending data first
            await rateLimitService.forceSyncAllPendingData();

            // Note: We don't actually implement a full reset, as it would require
            // scanning all Redis keys with our prefix, which could be expensive.
            // In production, it's better to let entries expire naturally.

            logger.info('RedisStore resetAll completed (pending data synced)');
        } catch (error) {
            logger.error('RedisStore resetAll error:', error);
            // Don't throw - this is an optional method
        }
    }

    /**
     * Get the current hit count for a key without (net) incrementing it.
     * This is a custom method not part of the Store interface.
     *
     * @param key - The identifier for a client.
     * @returns Current hit count and reset time, or null if no data exists.
     */
    async getHits(key: string): Promise<{ totalHits: number; resetTime: Date } | null> {
        try {
            const clientId = `${this.prefix}:${key}`;

            // Call checkRateLimit with an artificially high limit so this
            // check itself can't trip the rate limiter, then undo the
            // increment below
            const result = await rateLimitService.checkRateLimit(
                clientId,
                undefined,
                this.max + 1000, // Artificially high to avoid triggering the limit
                undefined,
                this.windowMs
            );

            // Decrement since we don't actually want to count this check
            await rateLimitService.decrementRateLimit(clientId);

            return {
                totalHits: Math.max(0, (result.totalHits || 0) - 1), // Adjust for the decrement
                resetTime: result.resetTime || new Date(Date.now() + this.windowMs)
            };
        } catch (error) {
            logger.error(`RedisStore getHits error for key ${key}:`, error);
            return null;
        }
    }

    /**
     * Cleanup method for graceful shutdown.
     * This is not part of the Store interface but is useful for cleanup.
     */
    async shutdown(): Promise<void> {
        try {
            // The rateLimitService handles its own cleanup
            logger.info('RedisStore shutdown completed');
        } catch (error) {
            logger.error('RedisStore shutdown error:', error);
        }
    }
}
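
/*
 * Wiring sketch (illustrative, not part of this file): plugging RedisStore
 * into express-rate-limit. The window and max values are made up; real
 * values come from config, as in apiServer.ts.
 *
 *     import rateLimit from "express-rate-limit";
 *     import RedisStore from "@server/db/private/redisStore";
 *
 *     const apiLimiter = rateLimit({
 *         windowMs: 60 * 1000, // 1 minute
 *         max: 100,            // per key per window
 *         store: new RedisStore({ prefix: "api" }), // init() is invoked by the middleware
 *         standardHeaders: true
 *     });
 *     // app.use("/api", apiLimiter);
 */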
@@ -1,4 +1,4 @@
import { db } from "@server/db";
import { db, loginPage, LoginPage, loginPageOrg } from "@server/db";
import {
    Resource,
    ResourcePassword,
@@ -39,7 +39,10 @@ export async function getResourceByDomain(
): Promise<ResourceWithAuth | null> {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(`${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/resource/domain/${domain}`, await tokenManager.getAuthHeader());
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/resource/domain/${domain}`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
@@ -91,7 +94,10 @@ export async function getUserSessionWithUser(
): Promise<UserSessionWithUser | null> {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(`${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/session/${userSessionId}`, await tokenManager.getAuthHeader());
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/session/${userSessionId}`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
@@ -132,7 +138,10 @@ export async function getUserSessionWithUser(
export async function getUserOrgRole(userId: string, orgId: string) {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(`${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/user/${userId}/org/${orgId}/role`, await tokenManager.getAuthHeader());
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/user/${userId}/org/${orgId}/role`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
@@ -154,12 +163,7 @@ export async function getUserOrgRole(userId: string, orgId: string) {
    const userOrgRole = await db
        .select()
        .from(userOrgs)
        .where(
            and(
                eq(userOrgs.userId, userId),
                eq(userOrgs.orgId, orgId)
            )
        )
        .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
        .limit(1);

    return userOrgRole.length > 0 ? userOrgRole[0] : null;
@@ -168,10 +172,16 @@ export async function getUserOrgRole(userId: string, orgId: string) {
/**
 * Check if role has access to resource
 */
export async function getRoleResourceAccess(resourceId: number, roleId: number) {
export async function getRoleResourceAccess(
    resourceId: number,
    roleId: number
) {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(`${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/role/${roleId}/resource/${resourceId}/access`, await tokenManager.getAuthHeader());
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/role/${roleId}/resource/${resourceId}/access`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
@@ -207,10 +217,16 @@ export async function getRoleResourceAccess(resourceId: number, roleId: number)
/**
 * Check if user has direct access to resource
 */
export async function getUserResourceAccess(userId: string, resourceId: number) {
export async function getUserResourceAccess(
    userId: string,
    resourceId: number
) {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(`${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/user/${userId}/resource/${resourceId}/access`, await tokenManager.getAuthHeader());
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/user/${userId}/resource/${resourceId}/access`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
@@ -246,10 +262,15 @@ export async function getUserResourceAccess(userId: string, resourceId: number)
/**
 * Get resource rules for a given resource
 */
export async function getResourceRules(resourceId: number): Promise<ResourceRule[]> {
export async function getResourceRules(
    resourceId: number
): Promise<ResourceRule[]> {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(`${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/resource/${resourceId}/rules`, await tokenManager.getAuthHeader());
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/resource/${resourceId}/rules`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
@@ -275,3 +296,50 @@ export async function getResourceRules(resourceId: number): Promise<ResourceRule

    return rules;
}

/**
 * Get organization login page
 */
export async function getOrgLoginPage(
    orgId: string
): Promise<LoginPage | null> {
    if (config.isManagedMode()) {
        try {
            const response = await axios.get(
                `${config.getRawConfig().managed?.endpoint}/api/v1/hybrid/org/${orgId}/login-page`,
                await tokenManager.getAuthHeader()
            );
            return response.data.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
                logger.error("Error fetching org login page:", {
                    message: error.message,
                    code: error.code,
                    status: error.response?.status,
                    statusText: error.response?.statusText,
                    url: error.config?.url,
                    method: error.config?.method
                });
            } else {
                logger.error("Error fetching org login page:", error);
            }
            return null;
        }
    }

    const [result] = await db
        .select()
        .from(loginPageOrg)
        .where(eq(loginPageOrg.orgId, orgId))
        .innerJoin(
            loginPage,
            eq(loginPageOrg.loginPageId, loginPage.loginPageId)
        )
        .limit(1);

    if (!result) {
        return null;
    }

    return result.loginPage;
}

@@ -1,2 +1,3 @@
export * from "./driver";
export * from "./schema";
export * from "./privateSchema";
239
server/db/sqlite/privateSchema.ts
Normal file
@@ -0,0 +1,239 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import {
    sqliteTable,
    integer,
    text,
    real
} from "drizzle-orm/sqlite-core";
import { InferSelectModel } from "drizzle-orm";
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";

export const certificates = sqliteTable("certificates", {
    certId: integer("certId").primaryKey({ autoIncrement: true }),
    domain: text("domain").notNull().unique(),
    domainId: text("domainId").references(() => domains.domainId, {
        onDelete: "cascade"
    }),
    wildcard: integer("wildcard", { mode: "boolean" }).default(false),
    status: text("status").notNull().default("pending"), // pending, requested, valid, expired, failed
    expiresAt: integer("expiresAt"),
    lastRenewalAttempt: integer("lastRenewalAttempt"),
    createdAt: integer("createdAt").notNull(),
    updatedAt: integer("updatedAt").notNull(),
    orderId: text("orderId"),
    errorMessage: text("errorMessage"),
    renewalCount: integer("renewalCount").default(0),
    certFile: text("certFile"),
    keyFile: text("keyFile")
});
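
/*
 * Query sketch (illustrative, not part of this file): look up a certificate
 * by domain with drizzle, following the select/where/limit pattern used
 * elsewhere in this change. Assumes the usual db handle and eq from
 * "drizzle-orm".
 *
 *     const [cert] = await db
 *         .select()
 *         .from(certificates)
 *         .where(eq(certificates.domain, "example.com"))
 *         .limit(1);
 *     if (cert && cert.status === "valid") {
 *         // cert.certFile / cert.keyFile hold the PEM material when present
 *     }
 */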

export const dnsChallenge = sqliteTable("dnsChallenges", {
    dnsChallengeId: integer("dnsChallengeId").primaryKey({ autoIncrement: true }),
    domain: text("domain").notNull(),
    token: text("token").notNull(),
    keyAuthorization: text("keyAuthorization").notNull(),
    createdAt: integer("createdAt").notNull(),
    expiresAt: integer("expiresAt").notNull(),
    completed: integer("completed", { mode: "boolean" }).default(false)
});

export const account = sqliteTable("account", {
    accountId: integer("accountId").primaryKey({ autoIncrement: true }),
    userId: text("userId")
        .notNull()
        .references(() => users.userId, { onDelete: "cascade" })
});

export const customers = sqliteTable("customers", {
    customerId: text("customerId").primaryKey().notNull(),
    orgId: text("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" }),
    // accountId: integer("accountId")
    //     .references(() => account.accountId, { onDelete: "cascade" }), // Optional, if using accounts
    email: text("email"),
    name: text("name"),
    phone: text("phone"),
    address: text("address"),
    createdAt: integer("createdAt").notNull(),
    updatedAt: integer("updatedAt").notNull()
});

export const subscriptions = sqliteTable("subscriptions", {
    subscriptionId: text("subscriptionId")
        .primaryKey()
        .notNull(),
    customerId: text("customerId")
        .notNull()
        .references(() => customers.customerId, { onDelete: "cascade" }),
    status: text("status").notNull().default("active"), // active, past_due, canceled, unpaid
    canceledAt: integer("canceledAt"),
    createdAt: integer("createdAt").notNull(),
    updatedAt: integer("updatedAt"),
    billingCycleAnchor: integer("billingCycleAnchor")
});

export const subscriptionItems = sqliteTable("subscriptionItems", {
    subscriptionItemId: integer("subscriptionItemId").primaryKey({ autoIncrement: true }),
    subscriptionId: text("subscriptionId")
        .notNull()
        .references(() => subscriptions.subscriptionId, {
            onDelete: "cascade"
        }),
    planId: text("planId").notNull(),
    priceId: text("priceId"),
    meterId: text("meterId"),
    unitAmount: real("unitAmount"),
    tiers: text("tiers"),
    interval: text("interval"),
    currentPeriodStart: integer("currentPeriodStart"),
    currentPeriodEnd: integer("currentPeriodEnd"),
    name: text("name")
});

export const accountDomains = sqliteTable("accountDomains", {
    accountId: integer("accountId")
        .notNull()
        .references(() => account.accountId, { onDelete: "cascade" }),
    domainId: text("domainId")
        .notNull()
        .references(() => domains.domainId, { onDelete: "cascade" })
});

export const usage = sqliteTable("usage", {
    usageId: text("usageId").primaryKey(),
    featureId: text("featureId").notNull(),
    orgId: text("orgId")
        .references(() => orgs.orgId, { onDelete: "cascade" })
        .notNull(),
    meterId: text("meterId"),
    instantaneousValue: real("instantaneousValue"),
    latestValue: real("latestValue").notNull(),
    previousValue: real("previousValue"),
    updatedAt: integer("updatedAt").notNull(),
    rolledOverAt: integer("rolledOverAt"),
    nextRolloverAt: integer("nextRolloverAt")
});

export const limits = sqliteTable("limits", {
    limitId: text("limitId").primaryKey(),
    featureId: text("featureId").notNull(),
    orgId: text("orgId")
        .references(() => orgs.orgId, {
            onDelete: "cascade"
        })
        .notNull(),
    value: real("value"),
    description: text("description")
});

export const usageNotifications = sqliteTable("usageNotifications", {
    notificationId: integer("notificationId").primaryKey({ autoIncrement: true }),
    orgId: text("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" }),
    featureId: text("featureId").notNull(),
    limitId: text("limitId").notNull(),
    notificationType: text("notificationType").notNull(),
    sentAt: integer("sentAt").notNull()
});

export const domainNamespaces = sqliteTable("domainNamespaces", {
    domainNamespaceId: text("domainNamespaceId").primaryKey(),
    domainId: text("domainId")
        .references(() => domains.domainId, {
            onDelete: "set null"
        })
        .notNull()
});

export const exitNodeOrgs = sqliteTable("exitNodeOrgs", {
    exitNodeId: integer("exitNodeId")
        .notNull()
        .references(() => exitNodes.exitNodeId, { onDelete: "cascade" }),
    orgId: text("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" })
});

export const remoteExitNodes = sqliteTable("remoteExitNode", {
    remoteExitNodeId: text("id").primaryKey(),
    secretHash: text("secretHash").notNull(),
    dateCreated: text("dateCreated").notNull(),
    version: text("version"),
    exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, {
        onDelete: "cascade"
    })
});

export const remoteExitNodeSessions = sqliteTable("remoteExitNodeSession", {
    sessionId: text("id").primaryKey(),
    remoteExitNodeId: text("remoteExitNodeId")
        .notNull()
        .references(() => remoteExitNodes.remoteExitNodeId, {
            onDelete: "cascade"
        }),
    expiresAt: integer("expiresAt").notNull()
});

export const loginPage = sqliteTable("loginPage", {
    loginPageId: integer("loginPageId").primaryKey({ autoIncrement: true }),
    subdomain: text("subdomain"),
    fullDomain: text("fullDomain"),
    exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, {
        onDelete: "set null"
    }),
    domainId: text("domainId").references(() => domains.domainId, {
        onDelete: "set null"
    })
});

export const loginPageOrg = sqliteTable("loginPageOrg", {
    loginPageId: integer("loginPageId")
        .notNull()
        .references(() => loginPage.loginPageId, { onDelete: "cascade" }),
    orgId: text("orgId")
        .notNull()
        .references(() => orgs.orgId, { onDelete: "cascade" })
});

export const sessionTransferToken = sqliteTable("sessionTransferToken", {
    token: text("token").primaryKey(),
    sessionId: text("sessionId")
        .notNull()
        .references(() => sessions.sessionId, {
            onDelete: "cascade"
        }),
    encryptedSession: text("encryptedSession").notNull(),
    expiresAt: integer("expiresAt").notNull()
});

export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
export type Certificate = InferSelectModel<typeof certificates>;
export type DnsChallenge = InferSelectModel<typeof dnsChallenge>;
export type Customer = InferSelectModel<typeof customers>;
export type Subscription = InferSelectModel<typeof subscriptions>;
export type SubscriptionItem = InferSelectModel<typeof subscriptionItems>;
export type Usage = InferSelectModel<typeof usage>;
export type UsageLimit = InferSelectModel<typeof limits>;
export type AccountDomain = InferSelectModel<typeof accountDomains>;
export type UsageNotification = InferSelectModel<typeof usageNotifications>;
export type RemoteExitNode = InferSelectModel<typeof remoteExitNodes>;
export type RemoteExitNodeSession = InferSelectModel<
    typeof remoteExitNodeSessions
>;
export type ExitNodeOrg = InferSelectModel<typeof exitNodeOrgs>;
export type LoginPage = InferSelectModel<typeof loginPage>;

@@ -140,6 +140,27 @@ export const targets = sqliteTable("targets", {
    rewritePathType: text("rewritePathType") // exact, prefix, regex, stripPrefix
});

export const targetHealthCheck = sqliteTable("targetHealthCheck", {
    targetHealthCheckId: integer("targetHealthCheckId").primaryKey({ autoIncrement: true }),
    targetId: integer("targetId")
        .notNull()
        .references(() => targets.targetId, { onDelete: "cascade" }),
    hcEnabled: integer("hcEnabled", { mode: "boolean" }).notNull().default(false),
    hcPath: text("hcPath"),
    hcScheme: text("hcScheme"),
    hcMode: text("hcMode").default("http"),
    hcHostname: text("hcHostname"),
    hcPort: integer("hcPort"),
    hcInterval: integer("hcInterval").default(30), // in seconds
    hcUnhealthyInterval: integer("hcUnhealthyInterval").default(30), // in seconds
    hcTimeout: integer("hcTimeout").default(5), // in seconds
    hcHeaders: text("hcHeaders"),
    hcFollowRedirects: integer("hcFollowRedirects", { mode: "boolean" }).default(true),
    hcMethod: text("hcMethod").default("GET"),
    hcStatus: integer("hcStatus"), // http code
    hcHealth: text("hcHealth").default("unknown") // "unknown", "healthy", "unhealthy"
});
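
/*
 * Insert sketch (illustrative): create a minimal health check row for a
 * target. The targetId and probe settings are made up; the real insert
 * happens in updateProxyResources further down in this diff, and db is the
 * usual drizzle handle.
 *
 *     const [hc] = await db
 *         .insert(targetHealthCheck)
 *         .values({
 *             targetId: 42,       // hypothetical target id
 *             hcEnabled: true,
 *             hcPath: "/healthz",
 *             hcScheme: "http",
 *             hcInterval: 30,     // seconds between probes
 *             hcTimeout: 5,       // per-probe timeout in seconds
 *             hcHealth: "unknown" // until the first probe completes
 *         })
 *         .returning();
 */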

export const exitNodes = sqliteTable("exitNodes", {
    exitNodeId: integer("exitNodeId").primaryKey({ autoIncrement: true }),
    name: text("name").notNull(),
@@ -458,18 +479,6 @@ export const userResources = sqliteTable("userResources", {
        .references(() => resources.resourceId, { onDelete: "cascade" })
});

export const limitsTable = sqliteTable("limits", {
    limitId: integer("limitId").primaryKey({ autoIncrement: true }),
    orgId: text("orgId")
        .references(() => orgs.orgId, {
            onDelete: "cascade"
        })
        .notNull(),
    name: text("name").notNull(),
    value: integer("value").notNull(),
    description: text("description")
});

export const userInvites = sqliteTable("userInvites", {
    inviteId: text("inviteId").primaryKey(),
    orgId: text("orgId")
@@ -714,7 +723,6 @@ export type RoleSite = InferSelectModel<typeof roleSites>;
export type UserSite = InferSelectModel<typeof userSites>;
export type RoleResource = InferSelectModel<typeof roleResources>;
export type UserResource = InferSelectModel<typeof userResources>;
export type Limit = InferSelectModel<typeof limitsTable>;
export type UserInvite = InferSelectModel<typeof userInvites>;
export type UserOrg = InferSelectModel<typeof userOrgs>;
export type ResourceSession = InferSelectModel<typeof resourceSessions>;
@@ -739,3 +747,4 @@ export type SiteResource = InferSelectModel<typeof siteResources>;
export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SetupToken = InferSelectModel<typeof setupTokens>;
export type HostMeta = InferSelectModel<typeof hostMeta>;
export type TargetHealthCheck = InferSelectModel<typeof targetHealthCheck>;
@@ -11,7 +11,7 @@ export async function sendEmail(
        from: string | undefined;
        to: string | undefined;
        subject: string;
    },
    }
) {
    if (!emailClient) {
        logger.warn("Email client not configured, skipping email send");
@@ -25,16 +25,16 @@ export async function sendEmail(

    const emailHtml = await render(template);

    const appName = "Pangolin";
    const appName = config.getRawPrivateConfig().branding?.app_name || "Pangolin";

    await emailClient.sendMail({
        from: {
            name: opts.name || appName,
            address: opts.from,
            address: opts.from
        },
        to: opts.to,
        subject: opts.subject,
        html: emailHtml,
        html: emailHtml
    });
}

@@ -0,0 +1,82 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
    EmailContainer,
    EmailFooter,
    EmailGreeting,
    EmailHeading,
    EmailLetterHead,
    EmailSignature,
    EmailText
} from "./components/Email";

interface Props {
    email: string;
    limitName: string;
    currentUsage: number;
    usageLimit: number;
    billingLink: string; // Link to billing page
}

export const NotifyUsageLimitApproaching = ({ email, limitName, currentUsage, usageLimit, billingLink }: Props) => {
    const previewText = `Your usage for ${limitName} is approaching the limit.`;
    const usagePercentage = Math.round((currentUsage / usageLimit) * 100);

    return (
        <Html>
            <Head />
            <Preview>{previewText}</Preview>
            <Tailwind config={themeColors}>
                <Body className="font-sans bg-gray-50">
                    <EmailContainer>
                        <EmailLetterHead />

                        <EmailHeading>Usage Limit Warning</EmailHeading>

                        <EmailGreeting>Hi there,</EmailGreeting>

                        <EmailText>
                            We wanted to let you know that your usage for <strong>{limitName}</strong> is approaching your plan limit.
                        </EmailText>

                        <EmailText>
                            <strong>Current Usage:</strong> {currentUsage} of {usageLimit} ({usagePercentage}%)
                        </EmailText>

                        <EmailText>
                            Once you reach your limit, some functionality may be restricted or your sites may disconnect until you upgrade your plan or your usage resets.
                        </EmailText>

                        <EmailText>
                            To avoid any interruption to your service, we recommend upgrading your plan or monitoring your usage closely. You can <a href={billingLink}>upgrade your plan here</a>.
                        </EmailText>

                        <EmailText>
                            If you have any questions or need assistance, please don't hesitate to reach out to our support team.
                        </EmailText>

                        <EmailFooter>
                            <EmailSignature />
                        </EmailFooter>
                    </EmailContainer>
                </Body>
            </Tailwind>
        </Html>
    );
};

export default NotifyUsageLimitApproaching;
84
server/emails/templates/PrivateNotifyUsageLimitReached.tsx
Normal file
@@ -0,0 +1,84 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
    EmailContainer,
    EmailFooter,
    EmailGreeting,
    EmailHeading,
    EmailLetterHead,
    EmailSignature,
    EmailText
} from "./components/Email";

interface Props {
    email: string;
    limitName: string;
    currentUsage: number;
    usageLimit: number;
    billingLink: string; // Link to billing page
}

export const NotifyUsageLimitReached = ({ email, limitName, currentUsage, usageLimit, billingLink }: Props) => {
    const previewText = `You've reached your ${limitName} usage limit - Action required`;
    const usagePercentage = Math.round((currentUsage / usageLimit) * 100);

    return (
        <Html>
            <Head />
            <Preview>{previewText}</Preview>
            <Tailwind config={themeColors}>
                <Body className="font-sans bg-gray-50">
                    <EmailContainer>
                        <EmailLetterHead />

                        <EmailHeading>Usage Limit Reached - Action Required</EmailHeading>

                        <EmailGreeting>Hi there,</EmailGreeting>

                        <EmailText>
                            You have reached your usage limit for <strong>{limitName}</strong>.
                        </EmailText>

                        <EmailText>
                            <strong>Current Usage:</strong> {currentUsage} of {usageLimit} ({usagePercentage}%)
                        </EmailText>

                        <EmailText>
                            <strong>Important:</strong> Your functionality may now be restricted and your sites may disconnect until you either upgrade your plan or your usage resets. To prevent any service interruption, immediate action is recommended.
                        </EmailText>

                        <EmailText>
                            <strong>What you can do:</strong>
                            <br />• <a href={billingLink} style={{ color: '#2563eb', fontWeight: 'bold' }}>Upgrade your plan immediately</a> to restore full functionality
                            <br />• Monitor your usage to stay within limits in the future
                        </EmailText>

                        <EmailText>
                            If you have any questions or need immediate assistance, please contact our support team right away.
                        </EmailText>

                        <EmailFooter>
                            <EmailSignature />
                        </EmailFooter>
                    </EmailContainer>
                </Body>
            </Tailwind>
        </Html>
    );
};

export default NotifyUsageLimitReached;
@@ -17,7 +17,7 @@ export function EmailLetterHead() {
        <div className="px-6 pt-8 pb-2 text-center">
            <Img
                src="https://fossorial-public-assets.s3.us-east-1.amazonaws.com/word_mark_black.png"
                alt="Fossorial"
                alt="Pangolin Logo"
                width="120"
                height="auto"
                className="mx-auto"
@@ -107,7 +107,7 @@ export function EmailSignature() {
        <p className="mb-2">
            Best regards,
            <br />
            <strong>The Fossorial Team</strong>
            <strong>The Pangolin Team</strong>
        </p>
    </div>
);

@@ -3,7 +3,7 @@ import config from "@server/lib/config";
import { createWebSocketClient } from "./routers/ws/client";
import { addPeer, deletePeer } from "./routers/gerbil/peers";
import { db, exitNodes } from "./db";
import { TraefikConfigManager } from "./lib/traefikConfig";
import { TraefikConfigManager } from "./lib/traefik/TraefikConfigManager";
import { tokenManager } from "./lib/tokenManager";
import { APP_VERSION } from "./lib/consts";
import axios from "axios";

@@ -5,13 +5,13 @@ import { runSetupFunctions } from "./setup";
import { createApiServer } from "./apiServer";
import { createNextServer } from "./nextServer";
import { createInternalServer } from "./internalServer";
import { ApiKey, ApiKeyOrg, Session, User, UserOrg } from "@server/db";
import { ApiKey, ApiKeyOrg, RemoteExitNode, Session, User, UserOrg } from "@server/db";
import { createIntegrationApiServer } from "./integrationApiServer";
import { createHybridClientServer } from "./hybridServer";
import config from "@server/lib/config";
import { setHostMeta } from "@server/lib/hostMeta";
import { initTelemetryClient } from "./lib/telemetry.js";
import { TraefikConfigManager } from "./lib/traefikConfig.js";
import { TraefikConfigManager } from "./lib/traefik/TraefikConfigManager.js";

async function startServers() {
    await setHostMeta();
@@ -63,6 +63,7 @@ declare global {
            userOrgRoleId?: number;
            userOrgId?: string;
            userOrgIds?: string[];
            remoteExitNode?: RemoteExitNode;
        }
    }
}

@@ -13,6 +13,10 @@ import helmet from "helmet";
import swaggerUi from "swagger-ui-express";
import { OpenApiGeneratorV3 } from "@asteasolutions/zod-to-openapi";
import { registry } from "./openApi";
import fs from "fs";
import path from "path";
import { APP_PATH } from "./lib/consts";
import yaml from "js-yaml";

const dev = process.env.ENVIRONMENT !== "prod";
const externalPort = config.getRawConfig().server.integration_port;
@@ -92,7 +96,7 @@ function getOpenApiDocumentation() {

    const generator = new OpenApiGeneratorV3(registry.definitions);

    return generator.generateDocument({
    const generated = generator.generateDocument({
        openapi: "3.0.0",
        info: {
            version: "v1",
@@ -100,4 +104,12 @@ function getOpenApiDocumentation() {
        },
        servers: [{ url: "/v1" }]
    });

    // convert to yaml and save to file
    const outputPath = path.join(APP_PATH, "openapi.yaml");
    const yamlOutput = yaml.dump(generated);
    fs.writeFileSync(outputPath, yamlOutput, "utf8");
    logger.info(`OpenAPI documentation saved to ${outputPath}`);

    return generated;
}
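
/*
 * Read-back sketch (illustrative, not part of this change): the persisted
 * spec can be loaded again, e.g. from a CLI or test helper. APP_PATH and the
 * file name match the write above; the rest is assumption.
 *
 *     const spec = yaml.load(
 *         fs.readFileSync(path.join(APP_PATH, "openapi.yaml"), "utf8")
 *     ) as Record<string, unknown>;
 *     console.log(`Loaded OpenAPI ${spec.openapi}`);
 */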
|
||||
|
||||
@@ -69,9 +69,16 @@ export async function applyBlueprint(
            `Updating target ${target.targetId} on site ${site.sites.siteId}`
        );

        // see if you can find a matching target health check from the healthchecksToUpdate array
        const matchingHealthcheck =
            result.healthchecksToUpdate.find(
                (hc) => hc.targetId === target.targetId
            );

        await addProxyTargets(
            site.newt.newtId,
            [target],
            matchingHealthcheck ? [matchingHealthcheck] : [],
            result.proxyResource.protocol,
            result.proxyResource.proxyPort
        );
@@ -8,6 +8,8 @@ import {
    roleResources,
    roles,
    Target,
    TargetHealthCheck,
    targetHealthCheck,
    Transaction,
    userOrgs,
    userResources,
@@ -22,6 +24,7 @@ import {
    TargetData
} from "./types";
import logger from "@server/logger";
import { createCertificate } from "@server/routers/private/certificates/createCertificate";
import { pickPort } from "@server/routers/target/helpers";
import { resourcePassword } from "@server/db";
import { hashPassword } from "@server/auth/password";
@@ -30,6 +33,7 @@ import { isValidCIDR, isValidIP, isValidUrlGlobPattern } from "../validators";
export type ProxyResourcesResults = {
    proxyResource: Resource;
    targetsToUpdate: Target[];
    healthchecksToUpdate: TargetHealthCheck[];
}[];

export async function updateProxyResources(
@@ -43,7 +47,8 @@ export async function updateProxyResources(
    for (const [resourceNiceId, resourceData] of Object.entries(
        config["proxy-resources"]
    )) {
        const targetsToUpdate: Target[] = [];
        let targetsToUpdate: Target[] = [];
        let healthchecksToUpdate: TargetHealthCheck[] = [];
        let resource: Resource;

        async function createTarget( // reusable function to create a target
@@ -114,6 +119,33 @@ export async function updateProxyResources(
                .returning();

            targetsToUpdate.push(newTarget);

            const healthcheckData = targetData.healthcheck;

            const hcHeaders = healthcheckData?.headers ? JSON.stringify(healthcheckData.headers) : null;

            const [newHealthcheck] = await trx
                .insert(targetHealthCheck)
                .values({
                    targetId: newTarget.targetId,
                    hcEnabled: healthcheckData?.enabled || false,
                    hcPath: healthcheckData?.path,
                    hcScheme: healthcheckData?.scheme,
                    hcMode: healthcheckData?.mode,
                    hcHostname: healthcheckData?.hostname,
                    hcPort: healthcheckData?.port,
                    hcInterval: healthcheckData?.interval,
                    hcUnhealthyInterval: healthcheckData?.unhealthyInterval,
                    hcTimeout: healthcheckData?.timeout,
                    hcHeaders: hcHeaders,
                    hcFollowRedirects: healthcheckData?.followRedirects,
                    hcMethod: healthcheckData?.method,
                    hcStatus: healthcheckData?.status,
                    hcHealth: "unknown"
                })
                .returning();

            healthchecksToUpdate.push(newHealthcheck);
        }

        // Find existing resource by niceId and orgId
@@ -360,6 +392,64 @@ export async function updateProxyResources(

                    targetsToUpdate.push(finalUpdatedTarget);
                }

                const healthcheckData = targetData.healthcheck;

                const [oldHealthcheck] = await trx
                    .select()
                    .from(targetHealthCheck)
                    .where(
                        eq(
                            targetHealthCheck.targetId,
                            existingTarget.targetId
                        )
                    )
                    .limit(1);

                const hcHeaders = healthcheckData?.headers ? JSON.stringify(healthcheckData.headers) : null;

                const [newHealthcheck] = await trx
                    .update(targetHealthCheck)
                    .set({
                        hcEnabled: healthcheckData?.enabled || false,
                        hcPath: healthcheckData?.path,
                        hcScheme: healthcheckData?.scheme,
                        hcMode: healthcheckData?.mode,
                        hcHostname: healthcheckData?.hostname,
                        hcPort: healthcheckData?.port,
                        hcInterval: healthcheckData?.interval,
                        hcUnhealthyInterval:
                            healthcheckData?.unhealthyInterval,
                        hcTimeout: healthcheckData?.timeout,
                        hcHeaders: hcHeaders,
                        hcFollowRedirects: healthcheckData?.followRedirects,
                        hcMethod: healthcheckData?.method,
                        hcStatus: healthcheckData?.status
                    })
                    .where(
                        eq(
                            targetHealthCheck.targetId,
                            existingTarget.targetId
                        )
                    )
                    .returning();

                if (
                    checkIfHealthcheckChanged(
                        oldHealthcheck,
                        newHealthcheck
                    )
                ) {
                    healthchecksToUpdate.push(newHealthcheck);
                    // if the target is not already in the targetsToUpdate array, add it
                    if (
                        !targetsToUpdate.find(
                            (t) => t.targetId === updatedTarget.targetId
                        )
                    ) {
                        targetsToUpdate.push(updatedTarget);
                    }
                }
            } else {
                await createTarget(existingResource.resourceId, targetData);
            }
@@ -573,7 +663,8 @@ export async function updateProxyResources(

        results.push({
            proxyResource: resource,
            targetsToUpdate
            targetsToUpdate,
            healthchecksToUpdate
        });
    }
@@ -783,6 +874,36 @@ async function syncWhitelistUsers(
    }
}

function checkIfHealthcheckChanged(
    existing: TargetHealthCheck | undefined,
    incoming: TargetHealthCheck | undefined
) {
    if (!existing && incoming) return true;
    if (existing && !incoming) return true;
    if (!existing || !incoming) return false;

    if (existing.hcEnabled !== incoming.hcEnabled) return true;
    if (existing.hcPath !== incoming.hcPath) return true;
    if (existing.hcScheme !== incoming.hcScheme) return true;
    if (existing.hcMode !== incoming.hcMode) return true;
    if (existing.hcHostname !== incoming.hcHostname) return true;
    if (existing.hcPort !== incoming.hcPort) return true;
    if (existing.hcInterval !== incoming.hcInterval) return true;
    if (existing.hcUnhealthyInterval !== incoming.hcUnhealthyInterval)
        return true;
    if (existing.hcTimeout !== incoming.hcTimeout) return true;
    if (existing.hcFollowRedirects !== incoming.hcFollowRedirects) return true;
    if (existing.hcMethod !== incoming.hcMethod) return true;
    if (existing.hcStatus !== incoming.hcStatus) return true;
    if (
        JSON.stringify(existing.hcHeaders) !==
        JSON.stringify(incoming.hcHeaders)
    )
        return true;

    return false;
}

function checkIfTargetChanged(
    existing: Target | undefined,
    incoming: Target | undefined
@@ -832,6 +953,8 @@ async function getDomain(
        );
    }

    await createCertificate(domain.domainId, fullDomain, trx);

    return domain;
}
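A quick illustration of how checkIfHealthcheckChanged behaves; the objects below are trimmed to the compared fields and cast for brevity (real rows carry the full column set):

    const before = { hcPath: "/health", hcInterval: 30 } as unknown as TargetHealthCheck;
    const after = { hcPath: "/health", hcInterval: 15 } as unknown as TargetHealthCheck;

    checkIfHealthcheckChanged(before, after);  // true: hcInterval differs
    checkIfHealthcheckChanged(before, before); // false: nothing differs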
@@ -5,6 +5,22 @@ export const SiteSchema = z.object({
    "docker-socket-enabled": z.boolean().optional().default(true)
});

export const TargetHealthCheckSchema = z.object({
    hostname: z.string(),
    port: z.number().int().min(1).max(65535),
    enabled: z.boolean().optional().default(true),
    path: z.string().optional(),
    scheme: z.string().optional(),
    mode: z.string().default("http"),
    interval: z.number().int().default(30),
    unhealthyInterval: z.number().int().default(30),
    timeout: z.number().int().default(5),
    headers: z.array(z.object({ name: z.string(), value: z.string() })).nullable().optional().default(null),
    followRedirects: z.boolean().default(true),
    method: z.string().default("GET"),
    status: z.number().int().optional()
});

// Schema for individual target within a resource
export const TargetSchema = z.object({
    site: z.string().optional(),
@@ -15,6 +31,7 @@ export const TargetSchema = z.object({
    "internal-port": z.number().int().min(1).max(65535).optional(),
    path: z.string().optional(),
    "path-match": z.enum(["exact", "prefix", "regex"]).optional().nullable(),
    healthcheck: TargetHealthCheckSchema.optional(),
    rewritePath: z.string().optional(),
    "rewrite-match": z.enum(["exact", "prefix", "regex", "stripPrefix"]).optional().nullable()
});
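As a rough sketch of what the new healthcheck schema accepts: only hostname and port are required, and the zod defaults fill in the rest. The values and import path below are illustrative, not taken from the repository:

    import { TargetHealthCheckSchema } from "./schema"; // assumed import path

    const hc = TargetHealthCheckSchema.parse({
        hostname: "10.0.0.5",
        port: 8080
    });

    // Defaults applied by the schema:
    // hc.mode === "http", hc.interval === 30, hc.timeout === 5,
    // hc.method === "GET", hc.followRedirects === true, hc.headers === null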
server/lib/colorsSchema.ts (new file, 29 lines)
@@ -0,0 +1,29 @@
import { z } from "zod";

export const colorsSchema = z.object({
    background: z.string().optional(),
    foreground: z.string().optional(),
    card: z.string().optional(),
    "card-foreground": z.string().optional(),
    popover: z.string().optional(),
    "popover-foreground": z.string().optional(),
    primary: z.string().optional(),
    "primary-foreground": z.string().optional(),
    secondary: z.string().optional(),
    "secondary-foreground": z.string().optional(),
    muted: z.string().optional(),
    "muted-foreground": z.string().optional(),
    accent: z.string().optional(),
    "accent-foreground": z.string().optional(),
    destructive: z.string().optional(),
    "destructive-foreground": z.string().optional(),
    border: z.string().optional(),
    input: z.string().optional(),
    ring: z.string().optional(),
    radius: z.string().optional(),
    "chart-1": z.string().optional(),
    "chart-2": z.string().optional(),
    "chart-3": z.string().optional(),
    "chart-4": z.string().optional(),
    "chart-5": z.string().optional()
});
@@ -6,9 +6,16 @@ import { eq } from "drizzle-orm";
import { license } from "@server/license/license";
import { configSchema, readConfigFile } from "./readConfigFile";
import { fromError } from "zod-validation-error";
import {
    privateConfigSchema,
    readPrivateConfigFile
} from "@server/lib/private/readConfigFile";
import logger from "@server/logger";
import { build } from "@server/build";

export class Config {
    private rawConfig!: z.infer<typeof configSchema>;
    private rawPrivateConfig!: z.infer<typeof privateConfigSchema>;

    supporterData: SupporterKey | null = null;

@@ -30,6 +37,19 @@ export class Config {
            throw new Error(`Invalid configuration file: ${errors}`);
        }

        const privateEnvironment = readPrivateConfigFile();

        const {
            data: parsedPrivateConfig,
            success: privateSuccess,
            error: privateError
        } = privateConfigSchema.safeParse(privateEnvironment);

        if (!privateSuccess) {
            const errors = fromError(privateError);
            throw new Error(`Invalid private configuration file: ${errors}`);
        }

        if (
            // @ts-ignore
            parsedConfig.users ||
@@ -89,7 +109,110 @@ export class Config {
                ? "true"
                : "false";

        if (parsedPrivateConfig.branding?.colors) {
            process.env.BRANDING_COLORS = JSON.stringify(
                parsedPrivateConfig.branding?.colors
            );
        }

        if (parsedPrivateConfig.branding?.logo?.light_path) {
            process.env.BRANDING_LOGO_LIGHT_PATH =
                parsedPrivateConfig.branding?.logo?.light_path;
        }
        if (parsedPrivateConfig.branding?.logo?.dark_path) {
            process.env.BRANDING_LOGO_DARK_PATH =
                parsedPrivateConfig.branding?.logo?.dark_path || undefined;
        }

        process.env.HIDE_SUPPORTER_KEY = parsedPrivateConfig.flags
            ?.hide_supporter_key
            ? "true"
            : "false";

        if (build != "oss") {
            if (parsedPrivateConfig.branding?.logo?.light_path) {
                process.env.BRANDING_LOGO_LIGHT_PATH =
                    parsedPrivateConfig.branding?.logo?.light_path;
            }
            if (parsedPrivateConfig.branding?.logo?.dark_path) {
                process.env.BRANDING_LOGO_DARK_PATH =
                    parsedPrivateConfig.branding?.logo?.dark_path || undefined;
            }

            process.env.BRANDING_LOGO_AUTH_WIDTH = parsedPrivateConfig.branding
                ?.logo?.auth_page?.width
                ? parsedPrivateConfig.branding?.logo?.auth_page?.width.toString()
                : undefined;
            process.env.BRANDING_LOGO_AUTH_HEIGHT = parsedPrivateConfig.branding
                ?.logo?.auth_page?.height
                ? parsedPrivateConfig.branding?.logo?.auth_page?.height.toString()
                : undefined;

            process.env.BRANDING_LOGO_NAVBAR_WIDTH = parsedPrivateConfig
                .branding?.logo?.navbar?.width
                ? parsedPrivateConfig.branding?.logo?.navbar?.width.toString()
                : undefined;
            process.env.BRANDING_LOGO_NAVBAR_HEIGHT = parsedPrivateConfig
                .branding?.logo?.navbar?.height
                ? parsedPrivateConfig.branding?.logo?.navbar?.height.toString()
                : undefined;

            process.env.BRANDING_FAVICON_PATH =
                parsedPrivateConfig.branding?.favicon_path;

            process.env.BRANDING_APP_NAME =
                parsedPrivateConfig.branding?.app_name || "Pangolin";

            if (parsedPrivateConfig.branding?.footer) {
                process.env.BRANDING_FOOTER = JSON.stringify(
                    parsedPrivateConfig.branding?.footer
                );
            }

            process.env.LOGIN_PAGE_TITLE_TEXT =
                parsedPrivateConfig.branding?.login_page?.title_text || "";
            process.env.LOGIN_PAGE_SUBTITLE_TEXT =
                parsedPrivateConfig.branding?.login_page?.subtitle_text || "";

            process.env.SIGNUP_PAGE_TITLE_TEXT =
                parsedPrivateConfig.branding?.signup_page?.title_text || "";
            process.env.SIGNUP_PAGE_SUBTITLE_TEXT =
                parsedPrivateConfig.branding?.signup_page?.subtitle_text || "";

            process.env.RESOURCE_AUTH_PAGE_HIDE_POWERED_BY =
                parsedPrivateConfig.branding?.resource_auth_page
                    ?.hide_powered_by === true
                    ? "true"
                    : "false";
            process.env.RESOURCE_AUTH_PAGE_SHOW_LOGO =
                parsedPrivateConfig.branding?.resource_auth_page?.show_logo ===
                true
                    ? "true"
                    : "false";
            process.env.RESOURCE_AUTH_PAGE_TITLE_TEXT =
                parsedPrivateConfig.branding?.resource_auth_page?.title_text ||
                "";
            process.env.RESOURCE_AUTH_PAGE_SUBTITLE_TEXT =
                parsedPrivateConfig.branding?.resource_auth_page
                    ?.subtitle_text || "";

            if (parsedPrivateConfig.branding?.background_image_path) {
                process.env.BACKGROUND_IMAGE_PATH =
                    parsedPrivateConfig.branding?.background_image_path;
            }

            if (parsedPrivateConfig.server.reo_client_id) {
                process.env.REO_CLIENT_ID =
                    parsedPrivateConfig.server.reo_client_id;
            }
        }

        if (parsedConfig.server.maxmind_db_path) {
            process.env.MAXMIND_DB_PATH = parsedConfig.server.maxmind_db_path;
        }

        this.rawConfig = parsedConfig;
        this.rawPrivateConfig = parsedPrivateConfig;
    }
    public async initServer() {
@@ -107,7 +230,11 @@ export class Config {

    private async checkKeyStatus() {
        const licenseStatus = await license.check();
        if (!licenseStatus.isHostLicensed) {
        if (
            !this.rawPrivateConfig.flags?.hide_supporter_key &&
            build != "oss" &&
            !licenseStatus.isHostLicensed
        ) {
            this.checkSupporterKey();
        }
    }
@@ -116,6 +243,10 @@ export class Config {
        return this.rawConfig;
    }

    public getRawPrivateConfig() {
        return this.rawPrivateConfig;
    }

    public getNoReplyEmail(): string | undefined {
        return (
            this.rawConfig.email?.no_reply || this.rawConfig.email?.smtp_user

@@ -11,3 +11,5 @@ export const APP_PATH = path.join("config");

export const configFilePath1 = path.join(APP_PATH, "config.yml");
export const configFilePath2 = path.join(APP_PATH, "config.yaml");

export const privateConfigFilePath1 = path.join(APP_PATH, "privateConfig.yml");
server/lib/encryption.ts (new file, 39 lines)
@@ -0,0 +1,39 @@
import crypto from 'crypto';

export function encryptData(data: string, key: Buffer): string {
    const algorithm = 'aes-256-gcm';
    const iv = crypto.randomBytes(16);
    const cipher = crypto.createCipheriv(algorithm, key, iv);

    let encrypted = cipher.update(data, 'utf8', 'hex');
    encrypted += cipher.final('hex');

    const authTag = cipher.getAuthTag();

    // Combine IV, auth tag, and encrypted data
    return iv.toString('hex') + ':' + authTag.toString('hex') + ':' + encrypted;
}

// Helper function to decrypt data (you'll need this to read certificates)
export function decryptData(encryptedData: string, key: Buffer): string {
    const algorithm = 'aes-256-gcm';
    const parts = encryptedData.split(':');

    if (parts.length !== 3) {
        throw new Error('Invalid encrypted data format');
    }

    const iv = Buffer.from(parts[0], 'hex');
    const authTag = Buffer.from(parts[1], 'hex');
    const encrypted = parts[2];

    const decipher = crypto.createDecipheriv(algorithm, key, iv);
    decipher.setAuthTag(authTag);

    let decrypted = decipher.update(encrypted, 'hex', 'utf8');
    decrypted += decipher.final('utf8');

    return decrypted;
}

// openssl rand -hex 32 > config/encryption.key
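A minimal round-trip sketch for these helpers, assuming a key generated with the openssl command above; the import path mirrors the file location and the plaintext is illustrative:

    import fs from "fs";
    import { encryptData, decryptData } from "@server/lib/encryption"; // assumed alias path

    // 64 hex chars from the key file decode to a 32-byte AES-256 key
    const key = Buffer.from(
        fs.readFileSync("config/encryption.key", "utf8").trim(),
        "hex"
    );

    const ciphertext = encryptData("certificate PEM body", key); // "iv:authTag:data", hex-encoded
    const plaintext = decryptData(ciphertext, key); // returns the original string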
@@ -16,7 +16,7 @@ export async function verifyExitNodeOrgAccess(
    return { hasAccess: true, exitNode };
}

export async function listExitNodes(orgId: string, filterOnline = false) {
export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
    // TODO: pick which nodes to send and ping better than just all of them that are not remote
    const allExitNodes = await db
        .select({
@@ -57,4 +57,9 @@ export function selectBestExitNode(

export async function checkExitNodeOrg(exitNodeId: number, orgId: string) {
    return false;
}
}

export async function resolveExitNodes(hostname: string, publicKey: string) {
    // OSS version: simple implementation that returns an empty array
    return [];
}
server/lib/exitNodes/getCurrentExitNodeId.ts (new file, 33 lines)
@@ -0,0 +1,33 @@
import { eq } from "drizzle-orm";
import { db, exitNodes } from "@server/db";
import config from "@server/lib/config";

let currentExitNodeId: number; // we really only need to look this up once per instance
export async function getCurrentExitNodeId(): Promise<number> {
    if (!currentExitNodeId) {
        if (config.getRawConfig().gerbil.exit_node_name) {
            const exitNodeName = config.getRawConfig().gerbil.exit_node_name!;
            const [exitNode] = await db
                .select({
                    exitNodeId: exitNodes.exitNodeId
                })
                .from(exitNodes)
                .where(eq(exitNodes.name, exitNodeName));
            if (exitNode) {
                currentExitNodeId = exitNode.exitNodeId;
            }
        } else {
            const [exitNode] = await db
                .select({
                    exitNodeId: exitNodes.exitNodeId
                })
                .from(exitNodes)
                .limit(1);

            if (exitNode) {
                currentExitNodeId = exitNode.exitNodeId;
            }
        }
    }
    return currentExitNodeId;
}
@@ -1,2 +1,33 @@
export * from "./exitNodes";
export * from "./shared";
import { build } from "@server/build";

// Import both modules
import * as exitNodesModule from "./exitNodes";
import * as privateExitNodesModule from "./privateExitNodes";

// Conditionally export exit nodes implementation based on build type
const exitNodesImplementation = build === "oss" ? exitNodesModule : privateExitNodesModule;

// Re-export all items from the selected implementation
export const {
    verifyExitNodeOrgAccess,
    listExitNodes,
    selectBestExitNode,
    checkExitNodeOrg,
    resolveExitNodes
} = exitNodesImplementation;

// Import communications modules
import * as exitNodeCommsModule from "./exitNodeComms";
import * as privateExitNodeCommsModule from "./privateExitNodeComms";

// Conditionally export communications implementation based on build type
const exitNodeCommsImplementation = build === "oss" ? exitNodeCommsModule : privateExitNodeCommsModule;

// Re-export communications functions from the selected implementation
export const {
    sendToExitNode
} = exitNodeCommsImplementation;

// Re-export shared modules
export * from "./subnet";
export * from "./getCurrentExitNodeId";
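With this barrel in place, callers stay agnostic of the build flavor; a consumer sketch (the orgId is illustrative, and the calls assume an async context):

    import { listExitNodes, selectBestExitNode } from "@server/lib/exitNodes";

    // Resolves to the OSS or private implementation depending on `build`.
    const nodes = await listExitNodes("org_abc", true);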
server/lib/exitNodes/privateExitNodeComms.ts (new file, 145 lines)
@@ -0,0 +1,145 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import axios from "axios";
import logger from "@server/logger";
import { db, ExitNode, remoteExitNodes } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "../../routers/ws";
import { config } from "../config";

interface ExitNodeRequest {
    remoteType?: string;
    localPath: string;
    method?: "POST" | "DELETE" | "GET" | "PUT";
    data?: any;
    queryParams?: Record<string, string>;
}

/**
 * Sends a request to an exit node, handling both remote and local exit nodes
 * @param exitNode The exit node to send the request to
 * @param request The request configuration
 * @returns Promise<any> Response data for local nodes, undefined for remote nodes
 */
export async function sendToExitNode(
    exitNode: ExitNode,
    request: ExitNodeRequest
): Promise<any> {
    if (exitNode.type === "remoteExitNode" && request.remoteType) {
        const [remoteExitNode] = await db
            .select()
            .from(remoteExitNodes)
            .where(eq(remoteExitNodes.exitNodeId, exitNode.exitNodeId))
            .limit(1);

        if (!remoteExitNode) {
            throw new Error(
                `Remote exit node with ID ${exitNode.exitNodeId} not found`
            );
        }

        return sendToClient(remoteExitNode.remoteExitNodeId, {
            type: request.remoteType,
            data: request.data
        });
    } else {
        let hostname = exitNode.reachableAt;

        logger.debug(`Exit node details:`, {
            type: exitNode.type,
            name: exitNode.name,
            reachableAt: exitNode.reachableAt,
        });

        logger.debug(`Configured local exit node name: ${config.getRawConfig().gerbil.exit_node_name}`);

        if (exitNode.name == config.getRawConfig().gerbil.exit_node_name) {
            hostname = config.getRawPrivateConfig().gerbil.local_exit_node_reachable_at;
        }

        if (!hostname) {
            throw new Error(
                `Exit node with ID ${exitNode.exitNodeId} is not reachable`
            );
        }

        logger.debug(`Sending request to exit node at ${hostname}`, {
            type: request.remoteType,
            data: request.data
        });

        // Handle local exit node with HTTP API
        const method = request.method || "POST";
        let url = `${hostname}${request.localPath}`;

        // Add query parameters if provided
        if (request.queryParams) {
            const params = new URLSearchParams(request.queryParams);
            url += `?${params.toString()}`;
        }

        try {
            let response;

            switch (method) {
                case "POST":
                    response = await axios.post(url, request.data, {
                        headers: {
                            "Content-Type": "application/json"
                        },
                        timeout: 8000
                    });
                    break;
                case "DELETE":
                    response = await axios.delete(url, {
                        timeout: 8000
                    });
                    break;
                case "GET":
                    response = await axios.get(url, {
                        timeout: 8000
                    });
                    break;
                case "PUT":
                    response = await axios.put(url, request.data, {
                        headers: {
                            "Content-Type": "application/json"
                        },
                        timeout: 8000
                    });
                    break;
                default:
                    throw new Error(`Unsupported HTTP method: ${method}`);
            }

            logger.debug(`Exit node request successful:`, {
                method,
                url,
                status: response.data.status
            });

            return response.data;
        } catch (error) {
            if (axios.isAxiosError(error)) {
                logger.error(
                    `Error making ${method} request (can Pangolin see Gerbil HTTP API?) for exit node at ${hostname} (status: ${error.response?.status}): ${error.message}`
                );
            } else {
                logger.error(
                    `Error making ${method} request for exit node at ${hostname}: ${error}`
                );
            }
        }
    }
}
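A usage sketch for sendToExitNode; the remoteType tag, local path, and payload here are hypothetical placeholders, not message types defined in this diff:

    await sendToExitNode(exitNode, {
        remoteType: "example/message", // taken when exitNode.type === "remoteExitNode"
        localPath: "/example",         // used for local nodes over the Gerbil HTTP API
        method: "POST",
        data: { key: "value" }         // illustrative payload
    });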
server/lib/exitNodes/privateExitNodes.ts (new file, 379 lines)
@@ -0,0 +1,379 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import {
    db,
    exitNodes,
    exitNodeOrgs,
    resources,
    targets,
    sites,
    targetHealthCheck
} from "@server/db";
import logger from "@server/logger";
import { ExitNodePingResult } from "@server/routers/newt";
import { eq, and, or, ne, isNull } from "drizzle-orm";
import axios from "axios";
import config from "../config";

/**
 * Checks if an exit node is actually online by making HTTP requests to its endpoint/ping
 * Makes up to 3 attempts in parallel with small delays, returns as soon as one succeeds
 */
async function checkExitNodeOnlineStatus(
    endpoint: string | undefined
): Promise<boolean> {
    if (!endpoint || endpoint == "") {
        // the endpoint can start out as an empty string
        return false;
    }

    const maxAttempts = 3;
    const timeoutMs = 5000; // 5 second timeout per request
    const delayBetweenAttempts = 100; // 100ms delay between starting each attempt

    // Create promises for all attempts with staggered delays
    const attemptPromises = Array.from({ length: maxAttempts }, async (_, index) => {
        const attemptNumber = index + 1;

        // Add delay before each attempt (except the first)
        if (index > 0) {
            await new Promise((resolve) => setTimeout(resolve, delayBetweenAttempts * index));
        }

        try {
            const response = await axios.get(`http://${endpoint}/ping`, {
                timeout: timeoutMs,
                validateStatus: (status) => status === 200
            });

            if (response.status === 200) {
                logger.debug(
                    `Exit node ${endpoint} is online (attempt ${attemptNumber}/${maxAttempts})`
                );
                return { success: true, attemptNumber };
            }
            return { success: false, attemptNumber, error: 'Non-200 status' };
        } catch (error) {
            const errorMessage = error instanceof Error ? error.message : "Unknown error";
            logger.debug(
                `Exit node ${endpoint} ping failed (attempt ${attemptNumber}/${maxAttempts}): ${errorMessage}`
            );
            return { success: false, attemptNumber, error: errorMessage };
        }
    });

    try {
        // Wait for the first successful response or all to fail
        const results = await Promise.allSettled(attemptPromises);

        // Check if any attempt succeeded
        for (const result of results) {
            if (result.status === 'fulfilled' && result.value.success) {
                return true;
            }
        }

        // All attempts failed
        logger.warn(
            `Exit node ${endpoint} is offline after ${maxAttempts} parallel attempts`
        );
        return false;
    } catch (error) {
        logger.warn(
            `Unexpected error checking exit node ${endpoint}: ${error instanceof Error ? error.message : "Unknown error"}`
        );
        return false;
    }
}

export async function verifyExitNodeOrgAccess(
    exitNodeId: number,
    orgId: string
) {
    const [result] = await db
        .select({
            exitNode: exitNodes,
            exitNodeOrgId: exitNodeOrgs.exitNodeId
        })
        .from(exitNodes)
        .leftJoin(
            exitNodeOrgs,
            and(
                eq(exitNodeOrgs.exitNodeId, exitNodes.exitNodeId),
                eq(exitNodeOrgs.orgId, orgId)
            )
        )
        .where(eq(exitNodes.exitNodeId, exitNodeId));

    if (!result) {
        return { hasAccess: false, exitNode: null };
    }

    const { exitNode } = result;

    // If the exit node is type "gerbil", access is allowed
    if (exitNode.type === "gerbil") {
        return { hasAccess: true, exitNode };
    }

    // If the exit node is type "remoteExitNode", check if it has org access
    if (exitNode.type === "remoteExitNode") {
        return { hasAccess: !!result.exitNodeOrgId, exitNode };
    }

    // For any other type, deny access
    return { hasAccess: false, exitNode };
}

export async function listExitNodes(orgId: string, filterOnline = false, noCloud = false) {
    const allExitNodes = await db
        .select({
            exitNodeId: exitNodes.exitNodeId,
            name: exitNodes.name,
            address: exitNodes.address,
            endpoint: exitNodes.endpoint,
            publicKey: exitNodes.publicKey,
            listenPort: exitNodes.listenPort,
            reachableAt: exitNodes.reachableAt,
            maxConnections: exitNodes.maxConnections,
            online: exitNodes.online,
            lastPing: exitNodes.lastPing,
            type: exitNodes.type,
            orgId: exitNodeOrgs.orgId,
            region: exitNodes.region
        })
        .from(exitNodes)
        .leftJoin(
            exitNodeOrgs,
            eq(exitNodes.exitNodeId, exitNodeOrgs.exitNodeId)
        )
        .where(
            or(
                // Include all exit nodes that are NOT of type remoteExitNode
                and(
                    eq(exitNodes.type, "gerbil"),
                    or(
                        // only choose nodes that are in the same region
                        eq(exitNodes.region, config.getRawPrivateConfig().app.region),
                        isNull(exitNodes.region) // or for enterprise where region is not set
                    )
                ),
                // Include remoteExitNode types where the orgId matches the newt's organization
                and(
                    eq(exitNodes.type, "remoteExitNode"),
                    eq(exitNodeOrgs.orgId, orgId)
                )
            )
        );

    // Filter the nodes. If there are NO remoteExitNodes then do nothing. If there are, remove all of the non-remoteExitNodes.
    if (allExitNodes.length === 0) {
        logger.warn("No exit nodes found for ping request!");
        return [];
    }

    // Enhanced online checking: consider a node offline if either the DB says offline OR the HTTP ping fails
    const nodesWithRealOnlineStatus = await Promise.all(
        allExitNodes.map(async (node) => {
            // If database says it's online, verify with HTTP ping
            let online: boolean;
            if (filterOnline && node.type == "remoteExitNode") {
                try {
                    const isActuallyOnline = await checkExitNodeOnlineStatus(
                        node.endpoint
                    );

                    // set the item in the database if it is offline
                    if (isActuallyOnline != node.online) {
                        await db
                            .update(exitNodes)
                            .set({ online: isActuallyOnline })
                            .where(eq(exitNodes.exitNodeId, node.exitNodeId));
                    }
                    online = isActuallyOnline;
                } catch (error) {
                    logger.warn(
                        `Failed to check online status for exit node ${node.name} (${node.endpoint}): ${error instanceof Error ? error.message : "Unknown error"}`
                    );
                    online = false;
                }
            } else {
                online = node.online;
            }

            return {
                ...node,
                online
            };
        })
    );

    const remoteExitNodes = nodesWithRealOnlineStatus.filter(
        (node) =>
            node.type === "remoteExitNode" && (!filterOnline || node.online)
    );
    const gerbilExitNodes = nodesWithRealOnlineStatus.filter(
        (node) => node.type === "gerbil" && (!filterOnline || node.online) && !noCloud
    );

    // THIS PROVIDES THE FALLBACK: prefer remote exit nodes, otherwise use gerbil nodes
    const exitNodesList =
        remoteExitNodes.length > 0 ? remoteExitNodes : gerbilExitNodes;

    return exitNodesList;
}

/**
 * Selects the most suitable exit node from a list of ping results.
 *
 * The selection algorithm follows these steps:
 *
 * 1. **Filter Invalid Nodes**: Excludes nodes with errors or zero weight.
 *
 * 2. **Sort by Latency**: Sorts valid nodes in ascending order of latency.
 *
 * 3. **Preferred Selection**:
 *    - If the lowest-latency node has sufficient capacity (≥10% weight),
 *      check if a previously connected node is also acceptable.
 *    - The previously connected node is preferred if its latency is within
 *      30ms or 15% of the best node's latency.
 *
 * 4. **Fallback to Next Best**:
 *    - If the lowest-latency node is under capacity, find the next node
 *      with acceptable capacity.
 *
 * 5. **Final Fallback**:
 *    - If no nodes meet the capacity threshold, fall back to the node
 *      with the highest weight (i.e., most available capacity).
 *
 */
export function selectBestExitNode(
    pingResults: ExitNodePingResult[]
): ExitNodePingResult | null {
    const MIN_CAPACITY_THRESHOLD = 0.1;
    const LATENCY_TOLERANCE_MS = 30;
    const LATENCY_TOLERANCE_PERCENT = 0.15;

    // Filter out invalid nodes
    const validNodes = pingResults.filter((n) => !n.error && n.weight > 0);

    if (validNodes.length === 0) {
        logger.error("No valid exit nodes available");
        return null;
    }

    // Sort by latency (ascending)
    const sortedNodes = validNodes
        .slice()
        .sort((a, b) => a.latencyMs - b.latencyMs);
    const lowestLatencyNode = sortedNodes[0];

    logger.debug(
        `Lowest latency node: ${lowestLatencyNode.exitNodeName} (${lowestLatencyNode.latencyMs} ms, weight=${lowestLatencyNode.weight.toFixed(2)})`
    );

    // If lowest latency node has enough capacity, check if previously connected node is acceptable
    if (lowestLatencyNode.weight >= MIN_CAPACITY_THRESHOLD) {
        const previouslyConnectedNode = sortedNodes.find(
            (n) =>
                n.wasPreviouslyConnected && n.weight >= MIN_CAPACITY_THRESHOLD
        );

        if (previouslyConnectedNode) {
            const latencyDiff =
                previouslyConnectedNode.latencyMs - lowestLatencyNode.latencyMs;
            const percentDiff = latencyDiff / lowestLatencyNode.latencyMs;

            if (
                latencyDiff <= LATENCY_TOLERANCE_MS ||
                percentDiff <= LATENCY_TOLERANCE_PERCENT
            ) {
                logger.info(
                    `Sticking with previously connected node: ${previouslyConnectedNode.exitNodeName} ` +
                        `(${previouslyConnectedNode.latencyMs} ms), latency diff = ${latencyDiff.toFixed(1)}ms ` +
                        `/ ${(percentDiff * 100).toFixed(1)}%.`
                );
                return previouslyConnectedNode;
            }
        }

        return lowestLatencyNode;
    }

    // Otherwise, find the next node (after the lowest) that has enough capacity
    for (let i = 1; i < sortedNodes.length; i++) {
        const node = sortedNodes[i];
        if (node.weight >= MIN_CAPACITY_THRESHOLD) {
            logger.info(
                `Lowest latency node under capacity. Using next best: ${node.exitNodeName} ` +
                    `(${node.latencyMs} ms, weight=${node.weight.toFixed(2)})`
            );
            return node;
        }
    }

    // Fallback: pick the highest weight node
    const fallbackNode = validNodes.reduce((a, b) =>
        a.weight > b.weight ? a : b
    );
    logger.warn(
        `No nodes with ≥10% weight. Falling back to highest capacity node: ${fallbackNode.exitNodeName}`
    );
    return fallbackNode;
}

export async function checkExitNodeOrg(exitNodeId: number, orgId: string) {
    const [exitNodeOrg] = await db
        .select()
        .from(exitNodeOrgs)
        .where(
            and(
                eq(exitNodeOrgs.exitNodeId, exitNodeId),
                eq(exitNodeOrgs.orgId, orgId)
            )
        )
        .limit(1);

    if (!exitNodeOrg) {
        return true;
    }

    return false;
}

export async function resolveExitNodes(hostname: string, publicKey: string) {
    const resourceExitNodes = await db
        .select({
            endpoint: exitNodes.endpoint,
            publicKey: exitNodes.publicKey,
            orgId: resources.orgId
        })
        .from(resources)
        .innerJoin(targets, eq(resources.resourceId, targets.resourceId))
        .leftJoin(
            targetHealthCheck,
            eq(targetHealthCheck.targetId, targets.targetId)
        )
        .innerJoin(sites, eq(targets.siteId, sites.siteId))
        .innerJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId))
        .where(
            and(
                eq(resources.fullDomain, hostname),
                ne(exitNodes.publicKey, publicKey),
                ne(targetHealthCheck.hcHealth, "unhealthy")
            )
        );

    return resourceExitNodes;
}
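To make the selection rules concrete, a hand-made example for selectBestExitNode; values are illustrative, fields not used by the function are omitted, and the array is cast for brevity:

    const pingResults = [
        { exitNodeName: "node-a", latencyMs: 12, weight: 0.05, wasPreviouslyConnected: false },
        { exitNodeName: "node-b", latencyMs: 18, weight: 0.6, wasPreviouslyConnected: true },
        { exitNodeName: "node-c", latencyMs: 90, weight: 0.9, wasPreviouslyConnected: false }
    ] as unknown as ExitNodePingResult[];

    // node-a is fastest but under the 10% capacity threshold, so the scan
    // moves to the next-lowest-latency node with capacity: node-b.
    const best = selectBestExitNode(pingResults); // -> node-b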
@@ -1,10 +1,42 @@
import logger from "@server/logger";
import { maxmindLookup } from "@server/db/maxmind";
import axios from "axios";
import config from "./config";
import { tokenManager } from "./tokenManager";
import logger from "@server/logger";

export async function getCountryCodeForIp(
    ip: string
): Promise<string | undefined> {
    try {
        if (!maxmindLookup) {
            logger.warn(
                "MaxMind DB path not configured, cannot perform GeoIP lookup"
            );
            return;
        }

        const result = maxmindLookup.get(ip);

        if (!result || !result.country) {
            return;
        }

        const { country } = result;

        logger.debug(
            `GeoIP lookup successful for IP ${ip}: ${country.iso_code}`
        );

        return country.iso_code;
    } catch (error) {
        logger.error("Error fetching config in verify session:", error);
    }

    return;
}

export async function remoteGetCountryCodeForIp(
    ip: string
): Promise<string | undefined> {
    try {
        const response = await axios.get(
@@ -1,8 +1,48 @@
import { db, loginPage, loginPageOrg } from "@server/db";
import config from "@server/lib/config";
import { eq } from "drizzle-orm";

export async function generateOidcRedirectUrl(
    idpId: number,
    orgId?: string,
    loginPageId?: number
): Promise<string> {
    let baseUrl: string | undefined;

    const secure = config.getRawConfig().app.dashboard_url?.startsWith("https");
    const method = secure ? "https" : "http";

    if (loginPageId) {
        const [res] = await db
            .select()
            .from(loginPage)
            .where(eq(loginPage.loginPageId, loginPageId))
            .limit(1);

        if (res && res.fullDomain) {
            baseUrl = `${method}://${res.fullDomain}`;
        }
    } else if (orgId) {
        const [res] = await db
            .select()
            .from(loginPageOrg)
            .where(eq(loginPageOrg.orgId, orgId))
            .innerJoin(
                loginPage,
                eq(loginPage.loginPageId, loginPageOrg.loginPageId)
            )
            .limit(1);

        if (res?.loginPage && res.loginPage.domainId && res.loginPage.fullDomain) {
            baseUrl = `${method}://${res.loginPage.fullDomain}`;
        }
    }

    if (!baseUrl) {
        baseUrl = config.getRawConfig().app.dashboard_url!;
    }

export function generateOidcRedirectUrl(idpId: number) {
    const dashboardUrl = config.getRawConfig().app.dashboard_url;
    const redirectPath = `/auth/idp/${idpId}/oidc/callback`;
    const redirectUrl = new URL(redirectPath, dashboardUrl).toString();
    const redirectUrl = new URL(redirectPath, baseUrl!).toString();
    return redirectUrl;
}
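The net effect, with illustrative values: when no login page row matches, the function falls back to the dashboard URL, so with dashboard_url set to "https://pangolin.example.com":

    const url = await generateOidcRedirectUrl(7);
    // -> "https://pangolin.example.com/auth/idp/7/oidc/callback"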
@@ -1,3 +0,0 @@
export * from "./response";
export { tokenManager, TokenManager } from "./tokenManager";
export * from "./geoip";
server/lib/private/billing/features.ts (new file, 85 lines)
@@ -0,0 +1,85 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import Stripe from "stripe";

export enum FeatureId {
    SITE_UPTIME = "siteUptime",
    USERS = "users",
    EGRESS_DATA_MB = "egressDataMb",
    DOMAINS = "domains",
    REMOTE_EXIT_NODES = "remoteExitNodes"
}

export const FeatureMeterIds: Record<FeatureId, string> = {
    [FeatureId.SITE_UPTIME]: "mtr_61Srrej5wUJuiTWgo41D3Ee2Ir7WmDLU",
    [FeatureId.USERS]: "mtr_61SrreISyIWpwUNGR41D3Ee2Ir7WmQro",
    [FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW",
    [FeatureId.DOMAINS]: "mtr_61Ss9nIKDNMw0LDRU41D3Ee2Ir7WmRPU",
    [FeatureId.REMOTE_EXIT_NODES]: "mtr_61T86UXnfxTVXy9sD41D3Ee2Ir7WmFTE"
};

export const FeatureMeterIdsSandbox: Record<FeatureId, string> = {
    [FeatureId.SITE_UPTIME]: "mtr_test_61Snh3cees4w60gv841DCpkOb237BDEu",
    [FeatureId.USERS]: "mtr_test_61Sn5fLtq1gSfRkyA41DCpkOb237B6au",
    [FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ",
    [FeatureId.DOMAINS]: "mtr_test_61SsA8qrdAlgPpFRQ41DCpkOb237BGts",
    [FeatureId.REMOTE_EXIT_NODES]: "mtr_test_61T86Vqmwa3D9ra3341DCpkOb237B94K"
};

export function getFeatureMeterId(featureId: FeatureId): string {
    if (process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") {
        return FeatureMeterIds[featureId];
    } else {
        return FeatureMeterIdsSandbox[featureId];
    }
}

export function getFeatureIdByMetricId(metricId: string): FeatureId | undefined {
    return (Object.entries(FeatureMeterIds) as [FeatureId, string][])
        .find(([_, v]) => v === metricId)?.[0];
}

export type FeaturePriceSet = {
    [key in FeatureId]: string;
};

export const standardFeaturePriceSet: FeaturePriceSet = { // Free tier matches the freeLimitSet
    [FeatureId.SITE_UPTIME]: "price_1RrQc4D3Ee2Ir7WmaJGZ3MtF",
    [FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea",
    [FeatureId.EGRESS_DATA_MB]: "price_1RrQXFD3Ee2Ir7WmvGDlgxQk",
    [FeatureId.DOMAINS]: "price_1Rz3tMD3Ee2Ir7Wm5qLeASzC",
    [FeatureId.REMOTE_EXIT_NODES]: "price_1S46weD3Ee2Ir7Wm94KEHI4h"
};

export const standardFeaturePriceSetSandbox: FeaturePriceSet = { // Free tier matches the freeLimitSet
    [FeatureId.SITE_UPTIME]: "price_1RefFBDCpkOb237BPrKZ8IEU",
    [FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF",
    [FeatureId.EGRESS_DATA_MB]: "price_1Rfp9LDCpkOb237BwuN5Oiu0",
    [FeatureId.DOMAINS]: "price_1Ryi88DCpkOb237B2D6DM80b",
    [FeatureId.REMOTE_EXIT_NODES]: "price_1RyiZvDCpkOb237BXpmoIYJL"
};

export function getStandardFeaturePriceSet(): FeaturePriceSet {
    if (process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") {
        return standardFeaturePriceSet;
    } else {
        return standardFeaturePriceSetSandbox;
    }
}

export function getLineItems(featurePriceSet: FeaturePriceSet): Stripe.Checkout.SessionCreateParams.LineItem[] {
    return Object.entries(featurePriceSet).map(([featureId, priceId]) => ({
        price: priceId,
    }));
}
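A sketch of how these helpers would feed a Stripe Checkout session; everything besides getLineItems and the price sets is illustrative (key source, customer id, session fields):

    const stripe = new Stripe(process.env.STRIPE_SECRET_KEY!); // hypothetical key source

    const session = await stripe.checkout.sessions.create({
        mode: "subscription",
        customer: "cus_...", // illustrative customer id
        line_items: getLineItems(getStandardFeaturePriceSet())
    });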
server/lib/private/billing/index.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

export * from "./limitSet";
export * from "./features";
export * from "./limitsService";
server/lib/private/billing/limitSet.ts (new file, 63 lines)
@@ -0,0 +1,63 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { FeatureId } from "./features";

export type LimitSet = {
    [key in FeatureId]: {
        value: number | null; // null indicates no limit
        description?: string;
    };
};

export const sandboxLimitSet: LimitSet = {
    [FeatureId.SITE_UPTIME]: { value: 2880, description: "Sandbox limit" }, // 1 site up for 2 days
    [FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
    [FeatureId.EGRESS_DATA_MB]: { value: 1000, description: "Sandbox limit" }, // 1 GB
    [FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
    [FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" },
};

export const freeLimitSet: LimitSet = {
    [FeatureId.SITE_UPTIME]: { value: 46080, description: "Free tier limit" }, // 1 site up for 32 days
    [FeatureId.USERS]: { value: 3, description: "Free tier limit" },
    [FeatureId.EGRESS_DATA_MB]: {
        value: 25000,
        description: "Free tier limit"
    }, // 25 GB
    [FeatureId.DOMAINS]: { value: 3, description: "Free tier limit" },
    [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Free tier limit" }
};

export const subscribedLimitSet: LimitSet = {
    [FeatureId.SITE_UPTIME]: {
        value: 2232000,
        description: "Contact us to increase soft limit.",
    }, // 50 sites up for 31 days
    [FeatureId.USERS]: {
        value: 150,
        description: "Contact us to increase soft limit."
    },
    [FeatureId.EGRESS_DATA_MB]: {
        value: 12000000,
        description: "Contact us to increase soft limit."
    }, // 12000 GB
    [FeatureId.DOMAINS]: {
        value: 20,
        description: "Contact us to increase soft limit."
    },
    [FeatureId.REMOTE_EXIT_NODES]: {
        value: 5,
        description: "Contact us to increase soft limit."
    }
};
server/lib/private/billing/limitsService.ts (new file, 51 lines)
@@ -0,0 +1,51 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { db, limits } from "@server/db";
import { and, eq } from "drizzle-orm";
import { LimitSet } from "./limitSet";
import { FeatureId } from "./features";

class LimitService {
    async applyLimitSetToOrg(orgId: string, limitSet: LimitSet): Promise<void> {
        const limitEntries = Object.entries(limitSet);

        // delete existing limits for the org
        await db.transaction(async (trx) => {
            await trx.delete(limits).where(eq(limits.orgId, orgId));
            for (const [featureId, entry] of limitEntries) {
                const limitId = `${orgId}-${featureId}`;
                const { value, description } = entry;
                await trx
                    .insert(limits)
                    .values({ limitId, orgId, featureId, value, description });
            }
        });
    }

    async getOrgLimit(
        orgId: string,
        featureId: FeatureId
    ): Promise<number | null> {
        const limitId = `${orgId}-${featureId}`;
        const [limit] = await db
            .select()
            .from(limits)
            .where(and(eq(limits.limitId, limitId)))
            .limit(1);

        return limit ? limit.value : null;
    }
}

export const limitsService = new LimitService();
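Usage sketch for the limits service, tying it to the limit sets above (the orgId is illustrative):

    await limitsService.applyLimitSetToOrg("org_abc", freeLimitSet);

    const maxUsers = await limitsService.getOrgLimit("org_abc", FeatureId.USERS);
    // -> 3 under freeLimitSet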
server/lib/private/billing/tiers.ts (new file, 37 lines)
@@ -0,0 +1,37 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

export enum TierId {
    STANDARD = "standard",
}

export type TierPriceSet = {
    [key in TierId]: string;
};

export const tierPriceSet: TierPriceSet = { // Free tier matches the freeLimitSet
    [TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0",
};

export const tierPriceSetSandbox: TierPriceSet = { // Free tier matches the freeLimitSet
    // when matching tier the keys closer to 0 index are matched first so list the tiers in descending order of value
    [TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m",
};

export function getTierPriceSet(environment?: string, sandbox_mode?: boolean): TierPriceSet {
    if ((process.env.ENVIRONMENT == "prod" && process.env.SANDBOX_MODE !== "true") || (environment === "prod" && sandbox_mode !== true)) { // THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
        return tierPriceSet;
    } else {
        return tierPriceSetSandbox;
    }
}
server/lib/private/billing/usageService.ts (new file, 889 lines)
@@ -0,0 +1,889 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { eq, sql, and } from "drizzle-orm";
import NodeCache from "node-cache";
import { v4 as uuidv4 } from "uuid";
import { PutObjectCommand } from "@aws-sdk/client-s3";
import { s3Client } from "../s3";
import * as fs from "fs/promises";
import * as path from "path";
import {
    db,
    usage,
    customers,
    sites,
    newts,
    limits,
    Usage,
    Limit,
    Transaction
} from "@server/db";
import { FeatureId, getFeatureMeterId } from "./features";
import config from "@server/lib/config";
import logger from "@server/logger";
import { sendToClient } from "@server/routers/ws";
import { build } from "@server/build";

interface StripeEvent {
    identifier?: string;
    timestamp: number;
    event_name: string;
    payload: {
        value: number;
        stripe_customer_id: string;
    };
}

export class UsageService {
    private cache: NodeCache;
    private bucketName: string | undefined;
    private currentEventFile: string | null = null;
    private currentFileStartTime: number = 0;
    private eventsDir: string | undefined;
    private uploadingFiles: Set<string> = new Set();

    constructor() {
        this.cache = new NodeCache({ stdTTL: 300 }); // 5 minute TTL
        if (build !== "saas") {
            return;
        }
        this.bucketName = config.getRawPrivateConfig().stripe?.s3Bucket;
        this.eventsDir = config.getRawPrivateConfig().stripe?.localFilePath;

        // Ensure events directory exists
        this.initializeEventsDirectory().then(() => {
            this.uploadPendingEventFilesOnStartup();
        });

        // Periodically check for old event files to upload
        setInterval(() => {
            this.uploadOldEventFiles().catch((err) => {
                logger.error("Error in periodic event file upload:", err);
            });
        }, 30000); // every 30 seconds
    }

    /**
     * Truncate a number to 11 decimal places to prevent precision issues
     */
    private truncateValue(value: number): number {
        return Math.round(value * 100000000000) / 100000000000; // 11 decimal places
    }

    private async initializeEventsDirectory(): Promise<void> {
        if (!this.eventsDir) {
            logger.warn("Stripe local file path is not configured, skipping events directory initialization.");
            return;
        }
        try {
            await fs.mkdir(this.eventsDir, { recursive: true });
        } catch (error) {
            logger.error("Failed to create events directory:", error);
        }
    }

    private async uploadPendingEventFilesOnStartup(): Promise<void> {
        if (!this.eventsDir || !this.bucketName) {
            logger.warn("Stripe local file path or bucket name is not configured, skipping leftover event file upload.");
            return;
        }
        try {
            const files = await fs.readdir(this.eventsDir);
            for (const file of files) {
                if (file.endsWith(".json")) {
                    const filePath = path.join(this.eventsDir, file);
                    try {
                        const fileContent = await fs.readFile(
                            filePath,
                            "utf-8"
                        );
                        const events = JSON.parse(fileContent);
                        if (Array.isArray(events) && events.length > 0) {
                            // Upload to S3
                            const uploadCommand = new PutObjectCommand({
                                Bucket: this.bucketName,
                                Key: file,
                                Body: fileContent,
                                ContentType: "application/json"
                            });
                            await s3Client.send(uploadCommand);

                            // Check if file still exists before unlinking
                            try {
                                await fs.access(filePath);
                                await fs.unlink(filePath);
                            } catch (unlinkError) {
                                logger.debug(`Startup file ${file} was already deleted`);
                            }

                            logger.info(
                                `Uploaded leftover event file ${file} to S3 with ${events.length} events`
                            );
                        } else {
                            // Remove empty file
                            try {
                                await fs.access(filePath);
                                await fs.unlink(filePath);
                            } catch (unlinkError) {
                                logger.debug(`Empty startup file ${file} was already deleted`);
                            }
                        }
                    } catch (err) {
                        logger.error(
                            `Error processing leftover event file ${file}:`,
                            err
                        );
                    }
                }
            }
        } catch (err) {
            logger.error("Failed to scan for leftover event files:", err);
        }
    }

    public async add(
        orgId: string,
        featureId: FeatureId,
        value: number,
        transaction: any = null
    ): Promise<Usage | null> {
        if (build !== "saas") {
            return null;
        }

        // Truncate value to 11 decimal places
        value = this.truncateValue(value);

        // Implement retry logic for deadlock handling
        const maxRetries = 3;
        let attempt = 0;

        while (attempt <= maxRetries) {
            try {
                // Get subscription data for this org (with caching)
                const customerId = await this.getCustomerId(orgId, featureId);

                if (!customerId) {
                    logger.warn(
                        `No subscription data found for org ${orgId} and feature ${featureId}`
                    );
                    return null;
                }

                let usage;
                if (transaction) {
                    usage = await this.internalAddUsage(
                        orgId,
                        featureId,
                        value,
                        transaction
                    );
                } else {
                    await db.transaction(async (trx) => {
                        usage = await this.internalAddUsage(orgId, featureId, value, trx);
                    });
                }

                // Log event for Stripe
                await this.logStripeEvent(featureId, value, customerId);

                return usage || null;
            } catch (error: any) {
                // Check if this is a deadlock error
                const isDeadlock = error?.code === '40P01' ||
                    error?.cause?.code === '40P01' ||
                    (error?.message && error.message.includes('deadlock'));

                if (isDeadlock && attempt < maxRetries) {
                    attempt++;
                    // Exponential backoff with jitter: 50-150ms, 100-300ms, 200-600ms
                    const baseDelay = Math.pow(2, attempt - 1) * 50;
                    const jitter = Math.random() * baseDelay;
                    const delay = baseDelay + jitter;

                    logger.warn(
                        `Deadlock detected for ${orgId}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms`
                    );

                    await new Promise(resolve => setTimeout(resolve, delay));
                    continue;
                }

                logger.error(
                    `Failed to add usage for ${orgId}/${featureId} after ${attempt} attempts:`,
                    error
                );
                break;
            }
        }

        return null;
    }

    private async internalAddUsage(
        orgId: string,
        featureId: FeatureId,
        value: number,
        trx: Transaction
    ): Promise<Usage> {
        // Truncate value to 11 decimal places
        value = this.truncateValue(value);

        const usageId = `${orgId}-${featureId}`;
        const meterId = getFeatureMeterId(featureId);

        // Use upsert: insert if not exists, otherwise increment
        const [returnUsage] = await trx
            .insert(usage)
            .values({
                usageId,
                featureId,
                orgId,
                meterId,
                latestValue: value,
                updatedAt: Math.floor(Date.now() / 1000)
            })
            .onConflictDoUpdate({
                target: usage.usageId,
                set: {
                    latestValue: sql`${usage.latestValue} + ${value}`
                }
            }).returning();

        return returnUsage;
    }

    // Helper function to get today's date as string (YYYY-MM-DD)
    getTodayDateString(): string {
        return new Date().toISOString().split("T")[0];
    }

    // Helper function to get date string (YYYY-MM-DD) from a unix timestamp in seconds
    getDateString(date: number): string {
        return new Date(date * 1000).toISOString().split("T")[0];
|
||||
}
|
||||
|
||||
async updateDaily(
|
||||
orgId: string,
|
||||
featureId: FeatureId,
|
||||
value?: number,
|
||||
customerId?: string
|
||||
): Promise<void> {
|
||||
if (build !== "saas") {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (!customerId) {
|
||||
customerId =
|
||||
(await this.getCustomerId(orgId, featureId)) || undefined;
|
||||
if (!customerId) {
|
||||
logger.warn(
|
||||
`No subscription data found for org ${orgId} and feature ${featureId}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate value to 11 decimal places if provided
|
||||
if (value !== undefined && value !== null) {
|
||||
value = this.truncateValue(value);
|
||||
}
|
||||
|
||||
const today = this.getTodayDateString();
|
||||
|
||||
let currentUsage: Usage | null = null;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// Get existing meter record
|
||||
const usageId = `${orgId}-${featureId}`;
|
||||
// Get current usage record
|
||||
[currentUsage] = await trx
|
||||
.select()
|
||||
.from(usage)
|
||||
.where(eq(usage.usageId, usageId))
|
||||
.limit(1);
|
||||
|
||||
if (currentUsage) {
|
||||
const lastUpdateDate = this.getDateString(
|
||||
currentUsage.updatedAt
|
||||
);
|
||||
const currentRunningTotal = currentUsage.latestValue;
|
||||
const lastDailyValue = currentUsage.instantaneousValue || 0;
|
||||
|
||||
if (value == undefined || value === null) {
|
||||
value = currentUsage.instantaneousValue || 0;
|
||||
}
|
||||
|
||||
if (lastUpdateDate === today) {
|
||||
// Same day update: replace the daily value
|
||||
// Remove old daily value from running total, add new value
|
||||
const newRunningTotal = this.truncateValue(
|
||||
currentRunningTotal - lastDailyValue + value
|
||||
);
|
||||
|
||||
await trx
|
||||
.update(usage)
|
||||
.set({
|
||||
latestValue: newRunningTotal,
|
||||
instantaneousValue: value,
|
||||
updatedAt: Math.floor(Date.now() / 1000)
|
||||
})
|
||||
.where(eq(usage.usageId, usageId));
|
||||
} else {
|
||||
// New day: add to running total
|
||||
const newRunningTotal = this.truncateValue(
|
||||
currentRunningTotal + value
|
||||
);
|
||||
|
||||
await trx
|
||||
.update(usage)
|
||||
.set({
|
||||
latestValue: newRunningTotal,
|
||||
instantaneousValue: value,
|
||||
updatedAt: Math.floor(Date.now() / 1000)
|
||||
})
|
||||
.where(eq(usage.usageId, usageId));
|
||||
}
|
||||
} else {
|
||||
// First record for this meter
|
||||
const meterId = getFeatureMeterId(featureId);
|
||||
const truncatedValue = this.truncateValue(value || 0);
|
||||
await trx.insert(usage).values({
|
||||
usageId,
|
||||
featureId,
|
||||
orgId,
|
||||
meterId,
|
||||
instantaneousValue: truncatedValue,
|
||||
latestValue: truncatedValue,
|
||||
updatedAt: Math.floor(Date.now() / 1000)
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
await this.logStripeEvent(featureId, value || 0, customerId);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to update daily usage for ${orgId}/${featureId}:`,
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private async getCustomerId(
|
||||
orgId: string,
|
||||
featureId: FeatureId
|
||||
): Promise<string | null> {
|
||||
const cacheKey = `customer_${orgId}_${featureId}`;
|
||||
const cached = this.cache.get<string>(cacheKey);
|
||||
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
try {
|
||||
// Query subscription data
|
||||
const [customer] = await db
|
||||
.select({
|
||||
customerId: customers.customerId
|
||||
})
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (!customer) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const customerId = customer.customerId;
|
||||
|
||||
// Cache the result
|
||||
this.cache.set(cacheKey, customerId);
|
||||
|
||||
return customerId;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to get subscription data for ${orgId}/${featureId}:`,
|
||||
error
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private async logStripeEvent(
|
||||
featureId: FeatureId,
|
||||
value: number,
|
||||
customerId: string
|
||||
): Promise<void> {
|
||||
// Truncate value to 11 decimal places before sending to Stripe
|
||||
const truncatedValue = this.truncateValue(value);
|
||||
|
||||
const event: StripeEvent = {
|
||||
identifier: uuidv4(),
|
||||
timestamp: Math.floor(new Date().getTime() / 1000),
|
||||
event_name: featureId,
|
||||
payload: {
|
||||
value: truncatedValue,
|
||||
stripe_customer_id: customerId
|
||||
}
|
||||
};
|
||||
|
||||
await this.writeEventToFile(event);
|
||||
await this.checkAndUploadFile();
|
||||
}
|
||||
|
||||
private async writeEventToFile(event: StripeEvent): Promise<void> {
|
||||
if (!this.eventsDir || !this.bucketName) {
|
||||
logger.warn("Stripe local file path or bucket name is not configured, skipping event file write.");
|
||||
return;
|
||||
}
|
||||
if (!this.currentEventFile) {
|
||||
this.currentEventFile = this.generateEventFileName();
|
||||
this.currentFileStartTime = Date.now();
|
||||
}
|
||||
|
||||
const filePath = path.join(this.eventsDir, this.currentEventFile);
|
||||
|
||||
try {
|
||||
let events: StripeEvent[] = [];
|
||||
|
||||
// Try to read existing file
|
||||
try {
|
||||
const fileContent = await fs.readFile(filePath, "utf-8");
|
||||
events = JSON.parse(fileContent);
|
||||
} catch (error) {
|
||||
// File doesn't exist or is empty, start with empty array
|
||||
events = [];
|
||||
}
|
||||
|
||||
// Add new event
|
||||
events.push(event);
|
||||
|
||||
// Write back to file
|
||||
await fs.writeFile(filePath, JSON.stringify(events, null, 2));
|
||||
} catch (error) {
|
||||
logger.error("Failed to write event to file:", error);
|
||||
}
|
||||
}
|
||||
|
||||
private async checkAndUploadFile(): Promise<void> {
|
||||
if (!this.currentEventFile) {
|
||||
return;
|
||||
}
|
||||
|
||||
const now = Date.now();
|
||||
const fileAge = now - this.currentFileStartTime;
|
||||
|
||||
// Check if file is at least 1 minute old
|
||||
if (fileAge >= 60000) {
|
||||
// 60 seconds
|
||||
await this.uploadFileToS3();
|
||||
}
|
||||
}
|
||||
|
||||
private async uploadFileToS3(): Promise<void> {
|
||||
if (!this.bucketName || !this.eventsDir) {
|
||||
logger.warn("Stripe local file path or bucket name is not configured, skipping S3 upload.");
|
||||
return;
|
||||
}
|
||||
if (!this.currentEventFile) {
|
||||
return;
|
||||
}
|
||||
|
||||
const fileName = this.currentEventFile;
|
||||
const filePath = path.join(this.eventsDir, fileName);
|
||||
|
||||
// Check if this file is already being uploaded
|
||||
if (this.uploadingFiles.has(fileName)) {
|
||||
logger.debug(`File ${fileName} is already being uploaded, skipping`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Mark file as being uploaded
|
||||
this.uploadingFiles.add(fileName);
|
||||
|
||||
try {
|
||||
// Check if file exists before trying to read it
|
||||
try {
|
||||
await fs.access(filePath);
|
||||
} catch (error) {
|
||||
logger.debug(`File ${fileName} does not exist, may have been already processed`);
|
||||
this.uploadingFiles.delete(fileName);
|
||||
// Reset current file if it was this file
|
||||
if (this.currentEventFile === fileName) {
|
||||
this.currentEventFile = null;
|
||||
this.currentFileStartTime = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if file exists and has content
|
||||
const fileContent = await fs.readFile(filePath, "utf-8");
|
||||
const events = JSON.parse(fileContent);
|
||||
|
||||
if (events.length === 0) {
|
||||
// No events to upload, just clean up
|
||||
try {
|
||||
await fs.unlink(filePath);
|
||||
} catch (unlinkError) {
|
||||
// File may have been already deleted
|
||||
logger.debug(`File ${fileName} was already deleted during cleanup`);
|
||||
}
|
||||
this.currentEventFile = null;
|
||||
this.uploadingFiles.delete(fileName);
|
||||
return;
|
||||
}
|
||||
|
||||
// Upload to S3
|
||||
const uploadCommand = new PutObjectCommand({
|
||||
Bucket: this.bucketName,
|
||||
Key: fileName,
|
||||
Body: fileContent,
|
||||
ContentType: "application/json"
|
||||
});
|
||||
|
||||
await s3Client.send(uploadCommand);
|
||||
|
||||
// Clean up local file - check if it still exists before unlinking
|
||||
try {
|
||||
await fs.access(filePath);
|
||||
await fs.unlink(filePath);
|
||||
} catch (unlinkError) {
|
||||
// File may have been already deleted by another process
|
||||
logger.debug(`File ${fileName} was already deleted during upload`);
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Uploaded ${fileName} to S3 with ${events.length} events`
|
||||
);
|
||||
|
||||
// Reset for next file
|
||||
this.currentEventFile = null;
|
||||
this.currentFileStartTime = 0;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to upload ${fileName} to S3:`,
|
||||
error
|
||||
);
|
||||
} finally {
|
||||
// Always remove from uploading set
|
||||
this.uploadingFiles.delete(fileName);
|
||||
}
|
||||
}
|
||||
|
||||
private generateEventFileName(): string {
|
||||
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
|
||||
const uuid = uuidv4().substring(0, 8);
|
||||
return `events-${timestamp}-${uuid}.json`;
|
||||
}
|
||||
|
||||
public async getUsage(
|
||||
orgId: string,
|
||||
featureId: FeatureId
|
||||
): Promise<Usage | null> {
|
||||
if (build !== "saas") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const usageId = `${orgId}-${featureId}`;
|
||||
|
||||
try {
|
||||
const [result] = await db
|
||||
.select()
|
||||
.from(usage)
|
||||
.where(eq(usage.usageId, usageId))
|
||||
.limit(1);
|
||||
|
||||
if (!result) {
|
||||
// Lets create one if it doesn't exist using upsert to handle race conditions
|
||||
logger.info(
|
||||
`Creating new usage record for ${orgId}/${featureId}`
|
||||
);
|
||||
const meterId = getFeatureMeterId(featureId);
|
||||
|
||||
try {
|
||||
const [newUsage] = await db
|
||||
.insert(usage)
|
||||
.values({
|
||||
usageId,
|
||||
featureId,
|
||||
orgId,
|
||||
meterId,
|
||||
latestValue: 0,
|
||||
updatedAt: Math.floor(Date.now() / 1000)
|
||||
})
|
||||
.onConflictDoNothing()
|
||||
.returning();
|
||||
|
||||
if (newUsage) {
|
||||
return newUsage;
|
||||
} else {
|
||||
// Record was created by another process, fetch it
|
||||
const [existingUsage] = await db
|
||||
.select()
|
||||
.from(usage)
|
||||
.where(eq(usage.usageId, usageId))
|
||||
.limit(1);
|
||||
return existingUsage || null;
|
||||
}
|
||||
} catch (insertError) {
|
||||
// Fallback: try to fetch existing record in case of any insert issues
|
||||
logger.warn(
|
||||
`Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`,
|
||||
insertError
|
||||
);
|
||||
const [existingUsage] = await db
|
||||
.select()
|
||||
.from(usage)
|
||||
.where(eq(usage.usageId, usageId))
|
||||
.limit(1);
|
||||
return existingUsage || null;
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to get usage for ${orgId}/${featureId}:`,
|
||||
error
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
public async getUsageDaily(
|
||||
orgId: string,
|
||||
featureId: FeatureId
|
||||
): Promise<Usage | null> {
|
||||
if (build !== "saas") {
|
||||
return null;
|
||||
}
|
||||
await this.updateDaily(orgId, featureId); // Ensure daily usage is updated
|
||||
return this.getUsage(orgId, featureId);
|
||||
}
|
||||
|
||||
public async forceUpload(): Promise<void> {
|
||||
await this.uploadFileToS3();
|
||||
}
|
||||
|
||||
public clearCache(): void {
|
||||
this.cache.flushAll();
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan the events directory for files older than 1 minute and upload them if not empty.
|
||||
*/
|
||||
private async uploadOldEventFiles(): Promise<void> {
|
||||
if (!this.eventsDir || !this.bucketName) {
|
||||
logger.warn("Stripe local file path or bucket name is not configured, skipping old event file upload.");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const files = await fs.readdir(this.eventsDir);
|
||||
const now = Date.now();
|
||||
for (const file of files) {
|
||||
if (!file.endsWith(".json")) continue;
|
||||
|
||||
// Skip files that are already being uploaded
|
||||
if (this.uploadingFiles.has(file)) {
|
||||
logger.debug(`Skipping file ${file} as it's already being uploaded`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const filePath = path.join(this.eventsDir, file);
|
||||
|
||||
try {
|
||||
// Check if file still exists before processing
|
||||
try {
|
||||
await fs.access(filePath);
|
||||
} catch (accessError) {
|
||||
logger.debug(`File ${file} does not exist, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const stat = await fs.stat(filePath);
|
||||
const age = now - stat.mtimeMs;
|
||||
if (age >= 90000) {
|
||||
// 1.5 minutes - Mark as being uploaded
|
||||
this.uploadingFiles.add(file);
|
||||
|
||||
try {
|
||||
const fileContent = await fs.readFile(
|
||||
filePath,
|
||||
"utf-8"
|
||||
);
|
||||
const events = JSON.parse(fileContent);
|
||||
if (Array.isArray(events) && events.length > 0) {
|
||||
// Upload to S3
|
||||
const uploadCommand = new PutObjectCommand({
|
||||
Bucket: this.bucketName,
|
||||
Key: file,
|
||||
Body: fileContent,
|
||||
ContentType: "application/json"
|
||||
});
|
||||
await s3Client.send(uploadCommand);
|
||||
|
||||
// Check if file still exists before unlinking
|
||||
try {
|
||||
await fs.access(filePath);
|
||||
await fs.unlink(filePath);
|
||||
} catch (unlinkError) {
|
||||
logger.debug(`File ${file} was already deleted during interval upload`);
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Interval: Uploaded event file ${file} to S3 with ${events.length} events`
|
||||
);
|
||||
// If this was the current event file, reset it
|
||||
if (this.currentEventFile === file) {
|
||||
this.currentEventFile = null;
|
||||
this.currentFileStartTime = 0;
|
||||
}
|
||||
} else {
|
||||
// Remove empty file
|
||||
try {
|
||||
await fs.access(filePath);
|
||||
await fs.unlink(filePath);
|
||||
} catch (unlinkError) {
|
||||
logger.debug(`Empty file ${file} was already deleted`);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
// Always remove from uploading set
|
||||
this.uploadingFiles.delete(file);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`Interval: Error processing event file ${file}:`,
|
||||
err
|
||||
);
|
||||
// Remove from uploading set on error
|
||||
this.uploadingFiles.delete(file);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error("Interval: Failed to scan for event files:", err);
|
||||
}
|
||||
}
|
||||
|
||||
public async checkLimitSet(orgId: string, kickSites = false, featureId?: FeatureId, usage?: Usage): Promise<boolean> {
|
||||
if (build !== "saas") {
|
||||
return false;
|
||||
}
|
||||
// This method should check the current usage against the limits set for the organization
|
||||
// and kick out all of the sites on the org
|
||||
let hasExceededLimits = false;
|
||||
|
||||
try {
|
||||
let orgLimits: Limit[] = [];
|
||||
if (featureId) {
|
||||
// Get all limits set for this organization
|
||||
orgLimits = await db
|
||||
.select()
|
||||
.from(limits)
|
||||
.where(
|
||||
and(
|
||||
eq(limits.orgId, orgId),
|
||||
eq(limits.featureId, featureId)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// Get all limits set for this organization
|
||||
orgLimits = await db
|
||||
.select()
|
||||
.from(limits)
|
||||
.where(eq(limits.orgId, orgId));
|
||||
}
|
||||
|
||||
if (orgLimits.length === 0) {
|
||||
logger.debug(`No limits set for org ${orgId}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check each limit against current usage
|
||||
for (const limit of orgLimits) {
|
||||
let currentUsage: Usage | null;
|
||||
if (usage) {
|
||||
currentUsage = usage;
|
||||
} else {
|
||||
currentUsage = await this.getUsage(orgId, limit.featureId as FeatureId);
|
||||
}
|
||||
|
||||
const usageValue = currentUsage?.instantaneousValue || currentUsage?.latestValue || 0;
|
||||
logger.debug(`Current usage for org ${orgId} on feature ${limit.featureId}: ${usageValue}`);
|
||||
logger.debug(`Limit for org ${orgId} on feature ${limit.featureId}: ${limit.value}`);
|
||||
if (currentUsage && limit.value !== null && usageValue > limit.value) {
|
||||
logger.debug(
|
||||
`Org ${orgId} has exceeded limit for ${limit.featureId}: ` +
|
||||
`${usageValue} > ${limit.value}`
|
||||
);
|
||||
hasExceededLimits = true;
|
||||
break; // Exit early if any limit is exceeded
|
||||
}
|
||||
}
|
||||
|
||||
// If any limits are exceeded, disconnect all sites for this organization
|
||||
if (hasExceededLimits && kickSites) {
|
||||
logger.warn(`Disconnecting all sites for org ${orgId} due to exceeded limits`);
|
||||
|
||||
// Get all sites for this organization
|
||||
const orgSites = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.orgId, orgId));
|
||||
|
||||
// Mark all sites as offline and send termination messages
|
||||
const siteUpdates = orgSites.map(site => site.siteId);
|
||||
|
||||
if (siteUpdates.length > 0) {
|
||||
// Send termination messages to newt sites
|
||||
for (const site of orgSites) {
|
||||
if (site.type === "newt") {
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, site.siteId))
|
||||
.limit(1);
|
||||
|
||||
if (newt) {
|
||||
const payload = {
|
||||
type: `newt/wg/terminate`,
|
||||
data: {
|
||||
reason: "Usage limits exceeded"
|
||||
}
|
||||
};
|
||||
|
||||
// Don't await to prevent blocking
|
||||
sendToClient(newt.newtId, payload).catch((error: any) => {
|
||||
logger.error(
|
||||
`Failed to send termination message to newt ${newt.newtId}:`,
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Disconnected ${orgSites.length} sites for org ${orgId} due to exceeded limits`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error checking limits for org ${orgId}:`, error);
|
||||
}
|
||||
|
||||
return hasExceededLimits;
|
||||
}
|
||||
}
|
||||
|
||||
export const usageService = new UsageService();
|
||||
206
server/lib/private/createUserAccountOrg.ts
Normal file
@@ -0,0 +1,206 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { isValidCIDR } from "@server/lib/validators";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import {
    actions,
    apiKeyOrg,
    apiKeys,
    db,
    domains,
    Org,
    orgDomains,
    orgs,
    roleActions,
    roles,
    userOrgs
} from "@server/db";
import { eq } from "drizzle-orm";
import { defaultRoleAllowedActions } from "@server/routers/role";
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/private/billing";
import { createCustomer } from "@server/routers/private/billing/createCustomer";
import { usageService } from "@server/lib/private/billing/usageService";

export async function createUserAccountOrg(
    userId: string,
    userEmail: string
): Promise<{
    success: boolean;
    org?: {
        orgId: string;
        name: string;
        subnet: string;
    };
    error?: string;
}> {
    // const subnet = await getNextAvailableOrgSubnet();
    const orgId = "org_" + userId;
    const name = `${userEmail}'s Organization`;

    // if (!isValidCIDR(subnet)) {
    //     return {
    //         success: false,
    //         error: "Invalid subnet format. Please provide a valid CIDR notation."
    //     };
    // }

    // // make sure the subnet is unique
    // const subnetExists = await db
    //     .select()
    //     .from(orgs)
    //     .where(eq(orgs.subnet, subnet))
    //     .limit(1);

    // if (subnetExists.length > 0) {
    //     return { success: false, error: `Subnet ${subnet} already exists` };
    // }

    // make sure the orgId is unique
    const orgExists = await db
        .select()
        .from(orgs)
        .where(eq(orgs.orgId, orgId))
        .limit(1);

    if (orgExists.length > 0) {
        return {
            success: false,
            error: `Organization with ID ${orgId} already exists`
        };
    }

    let error = "";
    let org: Org | null = null;

    await db.transaction(async (trx) => {
        const allDomains = await trx
            .select()
            .from(domains)
            .where(eq(domains.configManaged, true));

        const newOrg = await trx
            .insert(orgs)
            .values({
                orgId,
                name,
                // subnet
                subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
                createdAt: new Date().toISOString()
            })
            .returning();

        if (newOrg.length === 0) {
            error = "Failed to create organization";
            trx.rollback();
            return;
        }

        org = newOrg[0];

        // Create admin role within the same transaction
        const [insertedRole] = await trx
            .insert(roles)
            .values({
                orgId: newOrg[0].orgId,
                isAdmin: true,
                name: "Admin",
                description: "Admin role with the most permissions"
            })
            .returning({ roleId: roles.roleId });

        if (!insertedRole || !insertedRole.roleId) {
            error = "Failed to create Admin role";
            trx.rollback();
            return;
        }

        const roleId = insertedRole.roleId;

        // Get all actions and create role actions
        const actionIds = await trx.select().from(actions).execute();

        if (actionIds.length > 0) {
            await trx.insert(roleActions).values(
                actionIds.map((action) => ({
                    roleId,
                    actionId: action.actionId,
                    orgId: newOrg[0].orgId
                }))
            );
        }

        if (allDomains.length) {
            await trx.insert(orgDomains).values(
                allDomains.map((domain) => ({
                    orgId: newOrg[0].orgId,
                    domainId: domain.domainId
                }))
            );
        }

        await trx.insert(userOrgs).values({
            userId,
            orgId: newOrg[0].orgId,
            roleId: roleId,
            isOwner: true
        });

        const memberRole = await trx
            .insert(roles)
            .values({
                name: "Member",
                description: "Members can only view resources",
                orgId
            })
            .returning();

        await trx.insert(roleActions).values(
            defaultRoleAllowedActions.map((action) => ({
                roleId: memberRole[0].roleId,
                actionId: action,
                orgId
            }))
        );
    });

    await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet);

    if (!org) {
        return { success: false, error: "Failed to create org" };
    }

    if (error) {
        return {
            success: false,
            error: `Failed to create org: ${error}`
        };
    }

    // make sure we have the stripe customer
    const customerId = await createCustomer(orgId, userEmail);

    if (customerId) {
        await usageService.updateDaily(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are creating the org
    }

    return {
        org: {
            orgId,
            name,
            // subnet
            subnet: "100.90.128.0/24"
        },
        success: true
    };
}
25
server/lib/private/rateLimitStore.ts
Normal file
@@ -0,0 +1,25 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import RedisStore from "@server/db/private/redisStore";
import { MemoryStore, Store } from "express-rate-limit";

export function createStore(): Store {
    const rateLimitStore: Store = new RedisStore({
        prefix: 'api-rate-limit', // Optional: customize Redis key prefix
        skipFailedRequests: true, // Don't count failed requests
        skipSuccessfulRequests: false, // Count successful requests
    });

    return rateLimitStore;
}
192
server/lib/private/readConfigFile.ts
Normal file
@@ -0,0 +1,192 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import fs from "fs";
import yaml from "js-yaml";
import { privateConfigFilePath1 } from "@server/lib/consts";
import { z } from "zod";
import { colorsSchema } from "@server/lib/colorsSchema";
import { build } from "@server/build";

const portSchema = z.number().positive().gt(0).lte(65535);

export const privateConfigSchema = z
    .object({
        app: z.object({
            region: z.string().optional().default("default"),
            base_domain: z.string().optional()
        }).optional().default({
            region: "default"
        }),
        server: z.object({
            encryption_key_path: z
                .string()
                .optional()
                .default("./config/encryption.pem")
                .pipe(z.string().min(8)),
            resend_api_key: z.string().optional(),
            reo_client_id: z.string().optional(),
        }).optional().default({
            encryption_key_path: "./config/encryption.pem"
        }),
        redis: z
            .object({
                host: z.string(),
                port: portSchema,
                password: z.string().optional(),
                db: z.number().int().nonnegative().optional().default(0),
                replicas: z
                    .array(
                        z.object({
                            host: z.string(),
                            port: portSchema,
                            password: z.string().optional(),
                            db: z.number().int().nonnegative().optional().default(0)
                        })
                    )
                    .optional()
                // tls: z
                //     .object({
                //         reject_unauthorized: z
                //             .boolean()
                //             .optional()
                //             .default(true)
                //     })
                //     .optional()
            })
            .optional(),
        gerbil: z
            .object({
                local_exit_node_reachable_at: z.string().optional().default("http://gerbil:3003")
            })
            .optional()
            .default({}),
        flags: z
            .object({
                enable_redis: z.boolean().optional(),
                hide_supporter_key: z.boolean().optional()
            })
            .optional(),
        branding: z
            .object({
                app_name: z.string().optional(),
                background_image_path: z.string().optional(),
                colors: z
                    .object({
                        light: colorsSchema.optional(),
                        dark: colorsSchema.optional()
                    })
                    .optional(),
                logo: z
                    .object({
                        light_path: z.string().optional(),
                        dark_path: z.string().optional(),
                        auth_page: z
                            .object({
                                width: z.number().optional(),
                                height: z.number().optional()
                            })
                            .optional(),
                        navbar: z
                            .object({
                                width: z.number().optional(),
                                height: z.number().optional()
                            })
                            .optional()
                    })
                    .optional(),
                favicon_path: z.string().optional(),
                footer: z
                    .array(
                        z.object({
                            text: z.string(),
                            href: z.string().optional()
                        })
                    )
                    .optional(),
                login_page: z
                    .object({
                        subtitle_text: z.string().optional(),
                        title_text: z.string().optional()
                    })
                    .optional(),
                signup_page: z
                    .object({
                        subtitle_text: z.string().optional(),
                        title_text: z.string().optional()
                    })
                    .optional(),
                resource_auth_page: z
                    .object({
                        show_logo: z.boolean().optional(),
                        hide_powered_by: z.boolean().optional(),
                        title_text: z.string().optional(),
                        subtitle_text: z.string().optional()
                    })
                    .optional(),
                emails: z
                    .object({
                        signature: z.string().optional(),
                        colors: z
                            .object({
                                primary: z.string().optional()
                            })
                            .optional()
                    })
                    .optional()
            })
            .optional(),
        stripe: z
            .object({
                secret_key: z.string(),
                webhook_secret: z.string(),
                s3Bucket: z.string(),
                s3Region: z.string().default("us-east-1"),
                localFilePath: z.string()
            })
            .optional(),
    });

export function readPrivateConfigFile() {
    if (build == "oss") {
        return {};
    }

    const loadConfig = (configPath: string) => {
        try {
            const yamlContent = fs.readFileSync(configPath, "utf8");
            const config = yaml.load(yamlContent);
            return config;
        } catch (error) {
            if (error instanceof Error) {
                throw new Error(
                    `Error loading configuration file: ${error.message}`
                );
            }
            throw error;
        }
    };

    let environment: any;
    if (fs.existsSync(privateConfigFilePath1)) {
        environment = loadConfig(privateConfigFilePath1);
    }

    if (!environment) {
        throw new Error(
            "No private configuration file found."
        );
    }

    return environment;
}
124
server/lib/private/resend.ts
Normal file
@@ -0,0 +1,124 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Resend } from "resend";
import config from "../config";
import logger from "@server/logger";

export enum AudienceIds {
    General = "5cfbf99b-c592-40a9-9b8a-577a4681c158",
    Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
    Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549"
}

const resend = new Resend(
    config.getRawPrivateConfig().server.resend_api_key || "missing"
);

export default resend;

export async function moveEmailToAudience(
    email: string,
    audienceId: AudienceIds
) {
    if (process.env.ENVIRONMENT !== "prod") {
        logger.debug(`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`);
        return;
    }
    const { error, data } = await retryWithBackoff(async () => {
        const { data, error } = await resend.contacts.create({
            email,
            unsubscribed: false,
            audienceId
        });
        if (error) {
            throw new Error(
                `Error adding email ${email} to audience ${audienceId}: ${error}`
            );
        }
        return { error, data };
    });

    if (error) {
        logger.error(
            `Error adding email ${email} to audience ${audienceId}: ${error}`
        );
        return;
    }

    if (data) {
        logger.debug(
            `Added email ${email} to audience ${audienceId} with contact ID ${data.id}`
        );
    }

    const otherAudiences = Object.values(AudienceIds).filter(
        (id) => id !== audienceId
    );

    for (const otherAudienceId of otherAudiences) {
        const { error, data } = await retryWithBackoff(async () => {
            const { data, error } = await resend.contacts.remove({
                email,
                audienceId: otherAudienceId
            });
            if (error) {
                throw new Error(
                    `Error removing email ${email} from audience ${otherAudienceId}: ${error}`
                );
            }
            return { error, data };
        });

        if (error) {
            logger.error(
                `Error removing email ${email} from audience ${otherAudienceId}: ${error}`
            );
        }

        if (data) {
            logger.info(
                `Removed email ${email} from audience ${otherAudienceId}`
            );
        }
    }
}

type RetryOptions = {
    retries?: number;
    initialDelayMs?: number;
    factor?: number;
};

export async function retryWithBackoff<T>(
    fn: () => Promise<T>,
    options: RetryOptions = {}
): Promise<T> {
    const { retries = 5, initialDelayMs = 500, factor = 2 } = options;

    let attempt = 0;
    let delay = initialDelayMs;

    while (true) {
        try {
            return await fn();
        } catch (err) {
            attempt++;

            if (attempt > retries) throw err;

            await new Promise((resolve) => setTimeout(resolve, delay));
            delay *= factor;
        }
    }
}
19
server/lib/private/s3.ts
Normal file
@@ -0,0 +1,19 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { S3Client } from "@aws-sdk/client-s3";
import config from "@server/lib/config";

export const s3Client = new S3Client({
    region: config.getRawPrivateConfig().stripe?.s3Region || "us-east-1",
});
28
server/lib/private/stripe.ts
Normal file
@@ -0,0 +1,28 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import Stripe from "stripe";
import config from "@server/lib/config";
import logger from "@server/logger";
import { build } from "@server/build";

let stripe: Stripe | undefined = undefined;
if (build == "saas") {
    const stripeApiKey = config.getRawPrivateConfig().stripe?.secret_key;
    if (!stripeApiKey) {
        logger.error("Stripe secret key is not configured");
    }
    stripe = new Stripe(stripeApiKey!);
}

export default stripe;
@@ -30,7 +30,7 @@ export const configSchema = z
            anonymous_usage: z.boolean().optional().default(true)
        })
        .optional()
        .default({})
        .default({}),
    }).optional().default({
        log_level: "info",
        save_logs: false,
@@ -130,7 +130,8 @@ export const configSchema = z
        secret: z
            .string()
            .pipe(z.string().min(8))
            .optional()
            .optional(),
        maxmind_db_path: z.string().optional()
    }).optional().default({
        integration_port: 3003,
        external_port: 3000,

@@ -13,8 +13,8 @@ export async function getValidCertificatesForDomainsHybrid(domains: Set<string>)
        wildcard: boolean | null;
        certFile: string | null;
        keyFile: string | null;
        expiresAt: Date | null;
        updatedAt?: Date | null;
        expiresAt: number | null;
        updatedAt?: number | null;
    }>
> {
    if (domains.size === 0) {
@@ -72,8 +72,8 @@ export async function getValidCertificatesForDomains(domains: Set<string>): Prom
        wildcard: boolean | null;
        certFile: string | null;
        keyFile: string | null;
        expiresAt: Date | null;
        updatedAt?: Date | null;
        expiresAt: number | null;
        updatedAt?: number | null;
    }>
> {
    return []; // stub

@@ -1 +1,14 @@
export * from "./certificates";
import { build } from "@server/build";

// Import both modules
import * as certificateModule from "./certificates";
import * as privateCertificateModule from "./privateCertificates";

// Conditionally export Remote Certificates implementation based on build type
const remoteCertificatesImplementation = build === "oss" ? certificateModule : privateCertificateModule;

// Re-export all items from the selected implementation
export const {
    getValidCertificatesForDomains,
    getValidCertificatesForDomainsHybrid
} = remoteCertificatesImplementation;
116
server/lib/remoteCertificates/privateCertificates.ts
Normal file
@@ -0,0 +1,116 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import config from "../config";
import { certificates, db } from "@server/db";
import { and, eq, isNotNull } from "drizzle-orm";
import { decryptData } from "../encryption";
import * as fs from "fs";

export async function getValidCertificatesForDomains(
    domains: Set<string>
): Promise<
    Array<{
        id: number;
        domain: string;
        wildcard: boolean | null;
        certFile: string | null;
        keyFile: string | null;
        expiresAt: number | null;
        updatedAt?: number | null;
    }>
> {
    if (domains.size === 0) {
        return [];
    }

    const domainArray = Array.from(domains);

    // TODO: add more foreign keys to make this query more efficient - we don't need to keep getting every certificate
    const validCerts = await db
        .select({
            id: certificates.certId,
            domain: certificates.domain,
            certFile: certificates.certFile,
            keyFile: certificates.keyFile,
            expiresAt: certificates.expiresAt,
            updatedAt: certificates.updatedAt,
            wildcard: certificates.wildcard
        })
        .from(certificates)
        .where(
            and(
                eq(certificates.status, "valid"),
                isNotNull(certificates.certFile),
                isNotNull(certificates.keyFile)
            )
        );

    // Filter certificates for the specified domains; a wildcard certificate matches everything up to the first dot
    const validCertsFiltered = validCerts.filter((cert) => {
        return (
            domainArray.includes(cert.domain) ||
            (cert.wildcard &&
                domainArray.some((domain) =>
                    domain.endsWith(`.${cert.domain}`)
                ))
        );
    });

    const encryptionKeyPath = config.getRawPrivateConfig().server.encryption_key_path;

    if (!fs.existsSync(encryptionKeyPath)) {
        throw new Error(
            "Encryption key file not found. Please generate one first."
        );
    }

    const encryptionKeyHex = fs.readFileSync(encryptionKeyPath, "utf8").trim();
    const encryptionKey = Buffer.from(encryptionKeyHex, "hex");

    const validCertsDecrypted = validCertsFiltered.map((cert) => {
        // Decrypt the certificate file
        const decryptedCert = decryptData(
            cert.certFile!, // is not null from query
            encryptionKey
        );

        // Decrypt the key file
        const decryptedKey = decryptData(cert.keyFile!, encryptionKey);

        // Return only the certificate data without org information
        return {
            ...cert,
            certFile: decryptedCert,
            keyFile: decryptedKey
        };
    });

    return validCertsDecrypted;
}

export async function getValidCertificatesForDomainsHybrid(
    domains: Set<string>
): Promise<
    Array<{
        id: number;
        domain: string;
        wildcard: boolean | null;
        certFile: string | null;
        keyFile: string | null;
        expiresAt: number | null;
        updatedAt?: number | null;
    }>
> {
    return []; // stub
}
@@ -17,3 +17,12 @@ export const tlsNameSchema = z
|
||||
)
|
||||
.transform((val) => val.toLowerCase());
|
||||
|
||||
export const privateNamespaceSubdomainSchema = z
|
||||
.string()
|
||||
.regex(
|
||||
/^[a-zA-Z0-9-]+$/,
|
||||
"Namespace subdomain can only contain letters, numbers, and hyphens"
|
||||
)
|
||||
.min(1, "Namespace subdomain must be at least 1 character long")
|
||||
.max(32, "Namespace subdomain must be at most 32 characters long")
|
||||
.transform((val) => val.toLowerCase());
|
||||
|
||||
@@ -6,16 +6,15 @@ import * as yaml from "js-yaml";
|
||||
import axios from "axios";
|
||||
import { db, exitNodes } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { tokenManager } from "./tokenManager";
|
||||
import {
|
||||
getCurrentExitNodeId,
|
||||
getTraefikConfig
|
||||
} from "@server/routers/traefik";
|
||||
import { tokenManager } from "../tokenManager";
|
||||
import { getCurrentExitNodeId } from "@server/lib/exitNodes";
|
||||
import { getTraefikConfig } from "./";
|
||||
import {
|
||||
getValidCertificatesForDomains,
|
||||
getValidCertificatesForDomainsHybrid
|
||||
} from "./remoteCertificates";
|
||||
import { sendToExitNode } from "./exitNodeComms";
|
||||
} from "../remoteCertificates";
|
||||
import { sendToExitNode } from "../exitNodes";
|
||||
import { build } from "@server/build";
|
||||
|
||||
export class TraefikConfigManager {
|
||||
private intervalId: NodeJS.Timeout | null = null;
|
||||
@@ -28,8 +27,8 @@ export class TraefikConfigManager {
|
||||
string,
|
||||
{
|
||||
exists: boolean;
|
||||
lastModified: Date | null;
|
||||
expiresAt: Date | null;
|
||||
lastModified: number | null;
|
||||
expiresAt: number | null;
|
||||
wildcard: boolean | null;
|
||||
}
|
||||
>();
|
||||
@@ -115,8 +114,8 @@ export class TraefikConfigManager {
|
||||
string,
|
||||
{
|
||||
exists: boolean;
|
||||
lastModified: Date | null;
|
||||
expiresAt: Date | null;
|
||||
lastModified: number | null;
|
||||
expiresAt: number | null;
|
||||
wildcard: boolean;
|
||||
}
|
||||
>
|
||||
@@ -217,7 +216,12 @@ export class TraefikConfigManager {
|
||||
// Filter out domains covered by wildcard certificates
|
||||
const domainsNeedingCerts = new Set<string>();
|
||||
for (const domain of currentDomains) {
|
||||
if (!isDomainCoveredByWildcard(domain, this.lastLocalCertificateState)) {
|
||||
if (
|
||||
!isDomainCoveredByWildcard(
|
||||
domain,
|
||||
this.lastLocalCertificateState
|
||||
)
|
||||
) {
|
||||
domainsNeedingCerts.add(domain);
|
||||
}
|
||||
}
|
||||
@@ -225,7 +229,12 @@ export class TraefikConfigManager {
|
||||
// Fetch if domains needing certificates have changed
|
||||
const lastDomainsNeedingCerts = new Set<string>();
|
||||
for (const domain of this.lastKnownDomains) {
|
||||
if (!isDomainCoveredByWildcard(domain, this.lastLocalCertificateState)) {
|
||||
if (
|
||||
!isDomainCoveredByWildcard(
|
||||
domain,
|
||||
this.lastLocalCertificateState
|
||||
)
|
||||
) {
|
||||
lastDomainsNeedingCerts.add(domain);
|
||||
}
|
||||
}
|
||||
@@ -255,7 +264,7 @@ export class TraefikConfigManager {
|
||||
// Check if certificate is expiring soon (within 30 days)
|
||||
if (localState.expiresAt) {
|
||||
const daysUntilExpiry =
|
||||
(localState.expiresAt.getTime() - Date.now()) /
|
||||
(localState.expiresAt - Math.floor(Date.now() / 1000)) /
|
||||
(1000 * 60 * 60 * 24);
|
||||
if (daysUntilExpiry < 30) {
|
||||
logger.info(
|
||||
@@ -276,7 +285,7 @@ export class TraefikConfigManager {
|
||||
public async HandleTraefikConfig(): Promise<void> {
|
||||
try {
|
||||
// Get all active domains for this exit node via HTTP call
|
||||
const getTraefikConfig = await this.getTraefikConfig();
|
||||
const getTraefikConfig = await this.internalGetTraefikConfig();
|
||||
|
||||
if (!getTraefikConfig) {
|
||||
logger.error(
|
||||
@@ -315,15 +324,20 @@ export class TraefikConfigManager {
|
||||
wildcard: boolean | null;
|
||||
certFile: string | null;
|
||||
keyFile: string | null;
|
||||
expiresAt: Date | null;
|
||||
updatedAt?: Date | null;
|
||||
expiresAt: number | null;
|
||||
updatedAt?: number | null;
|
||||
}> = [];
|
||||
|
||||
if (this.shouldFetchCertificates(domains)) {
|
||||
// Filter out domains that are already covered by wildcard certificates
|
||||
const domainsToFetch = new Set<string>();
|
||||
for (const domain of domains) {
|
||||
if (!isDomainCoveredByWildcard(domain, this.lastLocalCertificateState)) {
|
||||
if (
|
||||
!isDomainCoveredByWildcard(
|
||||
domain,
|
||||
this.lastLocalCertificateState
|
||||
)
|
||||
) {
|
||||
domainsToFetch.add(domain);
|
||||
} else {
|
||||
logger.debug(
|
||||
@@ -339,7 +353,7 @@ export class TraefikConfigManager {
|
||||
await getValidCertificatesForDomainsHybrid(
|
||||
domainsToFetch
|
||||
);
|
||||
} else {
|
||||
} else {
|
||||
validCertificates =
|
||||
await getValidCertificatesForDomains(
|
||||
domainsToFetch
|
||||
@@ -428,7 +442,7 @@ export class TraefikConfigManager {
|
||||
/**
|
||||
* Get all domains currently in use from traefik config API
|
||||
*/
|
||||
private async getTraefikConfig(): Promise<{
|
||||
private async internalGetTraefikConfig(): Promise<{
|
||||
domains: Set<string>;
|
||||
traefikConfig: any;
|
||||
} | null> {
|
||||
@@ -451,9 +465,13 @@ export class TraefikConfigManager {
|
||||
traefikConfig = resp.data.data;
|
||||
} else {
|
||||
const currentExitNode = await getCurrentExitNodeId();
|
||||
// logger.debug(`Fetching traefik config for exit node: ${currentExitNode}`);
|
||||
traefikConfig = await getTraefikConfig(
|
||||
// this is called by the local exit node to get its own config
|
||||
currentExitNode,
|
||||
config.getRawConfig().traefik.site_types
|
||||
config.getRawConfig().traefik.site_types,
|
||||
build == "oss", // filter out the namespace domains in open source
|
||||
build != "oss" // generate the login pages on the cloud and hybrid
|
||||
);
|
||||
}
|
||||
|
||||
@@ -621,7 +639,8 @@ export class TraefikConfigManager {
|
||||
}
|
||||
|
||||
// If no exact match, check for wildcard certificates that cover this domain
|
||||
for (const [certDomain, certState] of this.lastLocalCertificateState) {
|
||||
for (const [certDomain, certState] of this
|
||||
.lastLocalCertificateState) {
|
||||
if (certState.exists && certState.wildcard) {
|
||||
// Check if this wildcard certificate covers the domain
|
||||
if (domain.endsWith("." + certDomain)) {
|
||||
@@ -671,8 +690,8 @@ export class TraefikConfigManager {
|
||||
wildcard: boolean | null;
|
||||
certFile: string | null;
|
||||
keyFile: string | null;
|
||||
expiresAt: Date | null;
|
||||
updatedAt?: Date | null;
|
||||
expiresAt: number | null;
|
||||
updatedAt?: number | null;
|
||||
}>
|
||||
): Promise<void> {
|
||||
const dynamicConfigPath =
|
||||
@@ -758,7 +777,7 @@ export class TraefikConfigManager {
|
||||
// Update local state tracking
|
||||
this.lastLocalCertificateState.set(cert.domain, {
|
||||
exists: true,
|
||||
lastModified: new Date(),
|
||||
lastModified: Math.floor(Date.now() / 1000),
|
||||
expiresAt: cert.expiresAt,
|
||||
wildcard: cert.wildcard
|
||||
});
|
||||
@@ -800,8 +819,8 @@ export class TraefikConfigManager {
|
||||
cert: {
|
||||
id: number;
|
||||
domain: string;
|
||||
expiresAt: Date | null;
|
||||
updatedAt?: Date | null;
|
||||
expiresAt: number | null;
|
||||
updatedAt?: number | null;
|
||||
},
|
||||
certPath: string,
|
||||
keyPath: string,
|
||||
@@ -818,12 +837,12 @@ export class TraefikConfigManager {
|
||||
}
|
||||
|
||||
// Read last update time from .last_update file
|
||||
let lastUpdateTime: Date | null = null;
|
||||
let lastUpdateTime: number | null = null;
|
||||
try {
|
||||
const lastUpdateStr = fs
|
||||
.readFileSync(lastUpdatePath, "utf8")
|
||||
.trim();
|
||||
lastUpdateTime = new Date(lastUpdateStr);
|
||||
lastUpdateTime = Math.floor(new Date(lastUpdateStr).getTime() / 1000);
|
||||
} catch {
|
||||
lastUpdateTime = null;
|
||||
}
|
||||
@@ -1004,7 +1023,12 @@ export class TraefikConfigManager {
|
||||
|
||||
// Find domains covered by wildcards
|
||||
for (const domain of this.activeDomains) {
|
||||
if (isDomainCoveredByWildcard(domain, this.lastLocalCertificateState)) {
|
||||
if (
|
||||
isDomainCoveredByWildcard(
|
||||
domain,
|
||||
this.lastLocalCertificateState
|
||||
)
|
||||
) {
|
||||
domainsCoveredByWildcards.push(domain);
|
||||
}
|
||||
}
|
||||
@@ -1025,7 +1049,13 @@ export class TraefikConfigManager {
|
||||
/**
|
||||
* Check if a domain is covered by existing wildcard certificates
|
||||
*/
|
||||
export function isDomainCoveredByWildcard(domain: string, lastLocalCertificateState: Map<string, { exists: boolean; wildcard: boolean | null }>): boolean {
|
||||
export function isDomainCoveredByWildcard(
|
||||
domain: string,
|
||||
lastLocalCertificateState: Map<
|
||||
string,
|
||||
{ exists: boolean; wildcard: boolean | null }
|
||||
>
|
||||
): boolean {
|
||||
for (const [certDomain, state] of lastLocalCertificateState) {
|
||||
if (state.exists && state.wildcard) {
|
||||
// If stored as example.com but is wildcard, check subdomains
|
||||
@@ -1,95 +1,14 @@
import { Request, Response } from "express";
import { db, exitNodes } from "@server/db";
import { db, exitNodes, targetHealthCheck } from "@server/db";
import { and, eq, inArray, or, isNull, ne, isNotNull } from "drizzle-orm";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import config from "@server/lib/config";
import { orgs, resources, sites, Target, targets } from "@server/db";
import { build } from "@server/build";
import createPathRewriteMiddleware from "./middleware";

let currentExitNodeId: number;
const redirectHttpsMiddlewareName = "redirect-to-https";
const badgerMiddlewareName = "badger";

export async function getCurrentExitNodeId(): Promise<number> {
    if (!currentExitNodeId) {
        if (config.getRawConfig().gerbil.exit_node_name) {
            const exitNodeName = config.getRawConfig().gerbil.exit_node_name!;
            const [exitNode] = await db
                .select({
                    exitNodeId: exitNodes.exitNodeId
                })
                .from(exitNodes)
                .where(eq(exitNodes.name, exitNodeName));
            if (exitNode) {
                currentExitNodeId = exitNode.exitNodeId;
            }
        } else {
            const [exitNode] = await db
                .select({
                    exitNodeId: exitNodes.exitNodeId
                })
                .from(exitNodes)
                .limit(1);

            if (exitNode) {
                currentExitNodeId = exitNode.exitNodeId;
            }
        }
    }
    return currentExitNodeId;
}

export async function traefikConfigProvider(
    _: Request,
    res: Response
): Promise<any> {
    try {
        // First query to get resources with site and org info
        // Get the current exit node name from config
        await getCurrentExitNodeId();

        const traefikConfig = await getTraefikConfig(
            currentExitNodeId,
            config.getRawConfig().traefik.site_types
        );

        if (traefikConfig?.http?.middlewares) {
            // BECAUSE SOMETIMES THE CONFIG CAN BE EMPTY IF THERE IS NOTHING
            traefikConfig.http.middlewares[badgerMiddlewareName] = {
                plugin: {
                    [badgerMiddlewareName]: {
                        apiBaseUrl: new URL(
                            "/api/v1",
                            `http://${
                                config.getRawConfig().server.internal_hostname
                            }:${config.getRawConfig().server.internal_port}`
                        ).href,
                        userSessionCookieName:
                            config.getRawConfig().server.session_cookie_name,

                        // deprecated
                        accessTokenQueryParam:
                            config.getRawConfig().server
                                .resource_access_token_param,

                        resourceSessionRequestParam:
                            config.getRawConfig().server
                                .resource_session_request_param
                    }
                }
            };
        }

        return res.status(HttpCode.OK).json(traefikConfig);
    } catch (e) {
        logger.error(`Failed to build Traefik config: ${e}`);
        return res.status(HttpCode.INTERNAL_SERVER_ERROR).json({
            error: "Failed to build Traefik config"
        });
    }
}


function validatePathRewriteConfig(
    path: string | null,
@@ -157,140 +76,11 @@ function validatePathRewriteConfig(
    return { isValid: true };
}


function createPathRewriteMiddleware(
    middlewareName: string,
    path: string,
    pathMatchType: string,
    rewritePath: string,
    rewritePathType: string
): { middlewares: { [key: string]: any }; chain?: string[] } {
    const middlewares: { [key: string]: any } = {};

    if (pathMatchType !== "regex" && !path.startsWith("/")) {
        path = `/${path}`;
    }

    if (rewritePathType !== "regex" && rewritePath !== "" && !rewritePath.startsWith("/")) {
        rewritePath = `/${rewritePath}`;
    }

    switch (rewritePathType) {
        case "exact":
            // Replace the path with the exact rewrite path
            let exactPattern = `^${escapeRegex(path)}$`;
            middlewares[middlewareName] = {
                replacePathRegex: {
                    regex: exactPattern,
                    replacement: rewritePath
                }
            };
            break;

        case "prefix":
            // Replace matched prefix with new prefix, preserve the rest
            switch (pathMatchType) {
                case "prefix":
                    middlewares[middlewareName] = {
                        replacePathRegex: {
                            regex: `^${escapeRegex(path)}(.*)`,
                            replacement: `${rewritePath}$1`
                        }
                    };
                    break;
                case "exact":
                    middlewares[middlewareName] = {
                        replacePathRegex: {
                            regex: `^${escapeRegex(path)}$`,
                            replacement: rewritePath
                        }
                    };
                    break;
                case "regex":
                    // For regex path matching with prefix rewrite, we assume the regex has capture groups
                    middlewares[middlewareName] = {
                        replacePathRegex: {
                            regex: path,
                            replacement: rewritePath
                        }
                    };
                    break;
            }
            break;

        case "regex":
            // Use advanced regex replacement - works with any match type
            let regexPattern: string;
            if (pathMatchType === "regex") {
                regexPattern = path;
            } else if (pathMatchType === "prefix") {
                regexPattern = `^${escapeRegex(path)}(.*)`;
            } else { // exact
                regexPattern = `^${escapeRegex(path)}$`;
            }

            middlewares[middlewareName] = {
                replacePathRegex: {
                    regex: regexPattern,
                    replacement: rewritePath
                }
            };
            break;

        case "stripPrefix":
            // Strip the matched prefix and optionally add new path
            if (pathMatchType === "prefix") {
                middlewares[middlewareName] = {
                    stripPrefix: {
                        prefixes: [path]
                    }
                };

                // If rewritePath is provided and not empty, add it as a prefix after stripping
                if (rewritePath && rewritePath !== "" && rewritePath !== "/") {
                    const addPrefixMiddlewareName = `addprefix-${middlewareName.replace('rewrite-', '')}`;
                    middlewares[addPrefixMiddlewareName] = {
                        addPrefix: {
                            prefix: rewritePath
                        }
                    };
                    return {
                        middlewares,
                        chain: [middlewareName, addPrefixMiddlewareName]
                    };
                }
            } else {
                // For exact and regex matches, use replacePathRegex to strip
                let regexPattern: string;
                if (pathMatchType === "exact") {
                    regexPattern = `^${escapeRegex(path)}$`;
                } else if (pathMatchType === "regex") {
                    regexPattern = path;
                } else {
                    regexPattern = `^${escapeRegex(path)}`;
                }

                const replacement = rewritePath || "/";
                middlewares[middlewareName] = {
                    replacePathRegex: {
                        regex: regexPattern,
                        replacement: replacement
                    }
                };
            }
            break;

        default:
            logger.error(`Unknown rewritePathType: ${rewritePathType}`);
            throw new Error(`Unknown rewritePathType: ${rewritePathType}`);
    }

    return { middlewares };
}

export async function getTraefikConfig(
    exitNodeId: number,
    siteTypes: string[]
    siteTypes: string[],
    filterOutNamespaceDomains = false,
    generateLoginPageRouters = false
): Promise<any> {
    // Define extended target type with site information
    type TargetWithSite = Target & {
@@ -329,6 +119,7 @@ export async function getTraefikConfig(
            method: targets.method,
            port: targets.port,
            internalPort: targets.internalPort,
            hcHealth: targetHealthCheck.hcHealth,
            path: targets.path,
            pathMatchType: targets.pathMatchType,
            rewritePath: targets.rewritePath,
@@ -343,11 +134,19 @@ export async function getTraefikConfig(
        .from(sites)
        .innerJoin(targets, eq(targets.siteId, sites.siteId))
        .innerJoin(resources, eq(resources.resourceId, targets.resourceId))
        .leftJoin(
            targetHealthCheck,
            eq(targetHealthCheck.targetId, targets.targetId)
        )
        .where(
            and(
                eq(targets.enabled, true),
                eq(resources.enabled, true),
                or(eq(sites.exitNodeId, exitNodeId), isNull(sites.exitNodeId)),
                or(
                    ne(targetHealthCheck.hcHealth, "unhealthy"), // Exclude unhealthy targets
                    isNull(targetHealthCheck.hcHealth) // Include targets with no health check record
                ),
                inArray(sites.type, siteTypes),
                config.getRawConfig().traefik.allow_raw_resources
                    ? isNotNull(resources.http) // ignore the http check if allow_raw_resources is true
@@ -863,8 +662,4 @@ function sanitizeForMiddlewareName(str: string): string {
    // Replace any characters that aren't alphanumeric or dash with dash
    // and remove consecutive dashes
    return str.replace(/[^a-zA-Z0-9-]/g, '-').replace(/-+/g, '-').replace(/^-|-$/g, '');
}

function escapeRegex(string: string): string {
    return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}
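To make the chain mechanics concrete, here is what the stripPrefix branch above produces for a prefix match with a replacement prefix (a worked example using the function as defined, with illustrative names):

// createPathRewriteMiddleware("rewrite-42", "/old", "prefix", "/new", "stripPrefix")
// returns two middlewares plus the order in which they must run:
const stripResult = {
    middlewares: {
        "rewrite-42": { stripPrefix: { prefixes: ["/old"] } },
        "addprefix-42": { addPrefix: { prefix: "/new" } }
    },
    chain: ["rewrite-42", "addprefix-42"]
};
// A request to /old/api/users is first stripped to /api/users,
// then prefixed to /new/api/users.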
server/lib/traefik/index.ts (new file, 11 lines)
@@ -0,0 +1,11 @@
import { build } from "@server/build";

// Import both modules
import * as traefikModule from "./getTraefikConfig";
import * as privateTraefikModule from "./privateGetTraefikConfig";

// Conditionally export Traefik configuration implementation based on build type
const traefikImplementation = build === "oss" ? traefikModule : privateTraefikModule;

// Re-export all items from the selected implementation
export const { getTraefikConfig } = traefikImplementation;
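Usage note: consumers import from the directory index and get whichever implementation matches the build; both variants now share the same signature, so call sites stay identical. A sketch (the site-type values are illustrative):

// Build-agnostic import; resolves to the OSS or private implementation at load time.
import { getTraefikConfig } from "@server/lib/traefik";

// const cfg = await getTraefikConfig(exitNodeId, ["newt", "wireguard", "local"]);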
server/lib/traefik/middleware.ts (new file, 140 lines)
@@ -0,0 +1,140 @@
import logger from "@server/logger";

export default function createPathRewriteMiddleware(
    middlewareName: string,
    path: string,
    pathMatchType: string,
    rewritePath: string,
    rewritePathType: string
): { middlewares: { [key: string]: any }; chain?: string[] } {
    const middlewares: { [key: string]: any } = {};

    if (pathMatchType !== "regex" && !path.startsWith("/")) {
        path = `/${path}`;
    }

    if (
        rewritePathType !== "regex" &&
        rewritePath !== "" &&
        !rewritePath.startsWith("/")
    ) {
        rewritePath = `/${rewritePath}`;
    }

    switch (rewritePathType) {
        case "exact":
            // Replace the path with the exact rewrite path
            let exactPattern = `^${escapeRegex(path)}$`;
            middlewares[middlewareName] = {
                replacePathRegex: {
                    regex: exactPattern,
                    replacement: rewritePath
                }
            };
            break;

        case "prefix":
            // Replace matched prefix with new prefix, preserve the rest
            switch (pathMatchType) {
                case "prefix":
                    middlewares[middlewareName] = {
                        replacePathRegex: {
                            regex: `^${escapeRegex(path)}(.*)`,
                            replacement: `${rewritePath}$1`
                        }
                    };
                    break;
                case "exact":
                    middlewares[middlewareName] = {
                        replacePathRegex: {
                            regex: `^${escapeRegex(path)}$`,
                            replacement: rewritePath
                        }
                    };
                    break;
                case "regex":
                    // For regex path matching with prefix rewrite, we assume the regex has capture groups
                    middlewares[middlewareName] = {
                        replacePathRegex: {
                            regex: path,
                            replacement: rewritePath
                        }
                    };
                    break;
            }
            break;

        case "regex":
            // Use advanced regex replacement - works with any match type
            let regexPattern: string;
            if (pathMatchType === "regex") {
                regexPattern = path;
            } else if (pathMatchType === "prefix") {
                regexPattern = `^${escapeRegex(path)}(.*)`;
            } else {
                // exact
                regexPattern = `^${escapeRegex(path)}$`;
            }

            middlewares[middlewareName] = {
                replacePathRegex: {
                    regex: regexPattern,
                    replacement: rewritePath
                }
            };
            break;

        case "stripPrefix":
            // Strip the matched prefix and optionally add new path
            if (pathMatchType === "prefix") {
                middlewares[middlewareName] = {
                    stripPrefix: {
                        prefixes: [path]
                    }
                };

                // If rewritePath is provided and not empty, add it as a prefix after stripping
                if (rewritePath && rewritePath !== "" && rewritePath !== "/") {
                    const addPrefixMiddlewareName = `addprefix-${middlewareName.replace("rewrite-", "")}`;
                    middlewares[addPrefixMiddlewareName] = {
                        addPrefix: {
                            prefix: rewritePath
                        }
                    };
                    return {
                        middlewares,
                        chain: [middlewareName, addPrefixMiddlewareName]
                    };
                }
            } else {
                // For exact and regex matches, use replacePathRegex to strip
                let regexPattern: string;
                if (pathMatchType === "exact") {
                    regexPattern = `^${escapeRegex(path)}$`;
                } else if (pathMatchType === "regex") {
                    regexPattern = path;
                } else {
                    regexPattern = `^${escapeRegex(path)}`;
                }

                const replacement = rewritePath || "/";
                middlewares[middlewareName] = {
                    replacePathRegex: {
                        regex: regexPattern,
                        replacement: replacement
                    }
                };
            }
            break;

        default:
            logger.error(`Unknown rewritePathType: ${rewritePathType}`);
            throw new Error(`Unknown rewritePathType: ${rewritePathType}`);
    }

    return { middlewares };
}

function escapeRegex(string: string): string {
    return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}
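And for the simpler prefix-to-prefix case, the same function emits a single replacePathRegex middleware (worked example, illustrative names):

// createPathRewriteMiddleware("rewrite-7", "/api", "prefix", "/v2", "prefix") returns:
const rewriteResult = {
    middlewares: {
        "rewrite-7": {
            replacePathRegex: {
                regex: "^/api(.*)",
                replacement: "/v2$1" // /api/users -> /v2/users
            }
        }
    }
};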
server/lib/traefik/privateGetTraefikConfig.ts (new file, 692 lines)
@@ -0,0 +1,692 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Request, Response } from "express";
import {
    certificates,
    db,
    domainNamespaces,
    exitNodes,
    loginPage,
    targetHealthCheck
} from "@server/db";
import { and, eq, inArray, or, isNull, ne, isNotNull } from "drizzle-orm";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import config from "@server/lib/config";
import { orgs, resources, sites, Target, targets } from "@server/db";
import { build } from "@server/build";

const redirectHttpsMiddlewareName = "redirect-to-https";
const redirectToRootMiddlewareName = "redirect-to-root";
const badgerMiddlewareName = "badger";

export async function getTraefikConfig(
    exitNodeId: number,
    siteTypes: string[],
    filterOutNamespaceDomains = false,
    generateLoginPageRouters = false
): Promise<any> {
    // Define extended target type with site information
    type TargetWithSite = Target & {
        site: {
            siteId: number;
            type: string;
            subnet: string | null;
            exitNodeId: number | null;
            online: boolean;
        };
    };

    // Get resources with their targets and sites in a single optimized query
    // Start from sites on this exit node, then join to targets and resources
    const resourcesWithTargetsAndSites = await db
        .select({
            // Resource fields
            resourceId: resources.resourceId,
            fullDomain: resources.fullDomain,
            ssl: resources.ssl,
            http: resources.http,
            proxyPort: resources.proxyPort,
            protocol: resources.protocol,
            subdomain: resources.subdomain,
            domainId: resources.domainId,
            enabled: resources.enabled,
            stickySession: resources.stickySession,
            tlsServerName: resources.tlsServerName,
            setHostHeader: resources.setHostHeader,
            enableProxy: resources.enableProxy,
            headers: resources.headers,
            // Target fields
            targetId: targets.targetId,
            targetEnabled: targets.enabled,
            ip: targets.ip,
            method: targets.method,
            port: targets.port,
            internalPort: targets.internalPort,
            hcHealth: targetHealthCheck.hcHealth,
            path: targets.path,
            pathMatchType: targets.pathMatchType,

            // Site fields
            siteId: sites.siteId,
            siteType: sites.type,
            siteOnline: sites.online,
            subnet: sites.subnet,
            exitNodeId: sites.exitNodeId,
            // Namespace
            domainNamespaceId: domainNamespaces.domainNamespaceId,
            // Certificate
            certificateStatus: certificates.status
        })
        .from(sites)
        .innerJoin(targets, eq(targets.siteId, sites.siteId))
        .innerJoin(resources, eq(resources.resourceId, targets.resourceId))
        .leftJoin(certificates, eq(certificates.domainId, resources.domainId))
        .leftJoin(
            targetHealthCheck,
            eq(targetHealthCheck.targetId, targets.targetId)
        )
        .leftJoin(
            domainNamespaces,
            eq(domainNamespaces.domainId, resources.domainId)
        ) // THIS IS CLOUD ONLY TO FILTER OUT THE DOMAIN NAMESPACES IF REQUIRED
        .where(
            and(
                eq(targets.enabled, true),
                eq(resources.enabled, true),
                // or(
                eq(sites.exitNodeId, exitNodeId),
                // isNull(sites.exitNodeId)
                // ),
                or(
                    ne(targetHealthCheck.hcHealth, "unhealthy"), // Exclude unhealthy targets
                    isNull(targetHealthCheck.hcHealth) // Include targets with no health check record
                ),
                inArray(sites.type, siteTypes),
                config.getRawConfig().traefik.allow_raw_resources
                    ? isNotNull(resources.http) // ignore the http check if allow_raw_resources is true
                    : eq(resources.http, true)
            )
        );

    // Group by resource and include targets with their unique site data
    const resourcesMap = new Map();

    resourcesWithTargetsAndSites.forEach((row) => {
        const resourceId = row.resourceId;
        const targetPath = sanitizePath(row.path) || ""; // Handle null/undefined paths
        const pathMatchType = row.pathMatchType || "";

        if (filterOutNamespaceDomains && row.domainNamespaceId) {
            return;
        }

        // Create a unique key combining resourceId and path+pathMatchType
        const pathKey = [targetPath, pathMatchType].filter(Boolean).join("-");
        const mapKey = [resourceId, pathKey].filter(Boolean).join("-");

        if (!resourcesMap.has(mapKey)) {
            resourcesMap.set(mapKey, {
                resourceId: row.resourceId,
                fullDomain: row.fullDomain,
                ssl: row.ssl,
                http: row.http,
                proxyPort: row.proxyPort,
                protocol: row.protocol,
                subdomain: row.subdomain,
                domainId: row.domainId,
                enabled: row.enabled,
                stickySession: row.stickySession,
                tlsServerName: row.tlsServerName,
                setHostHeader: row.setHostHeader,
                enableProxy: row.enableProxy,
                certificateStatus: row.certificateStatus,
                targets: [],
                headers: row.headers,
                path: row.path, // the targets will all have the same path
                pathMatchType: row.pathMatchType // the targets will all have the same pathMatchType
            });
        }

        // Add target with its associated site data
        resourcesMap.get(mapKey).targets.push({
            resourceId: row.resourceId,
            targetId: row.targetId,
            ip: row.ip,
            method: row.method,
            port: row.port,
            internalPort: row.internalPort,
            enabled: row.targetEnabled,
            site: {
                siteId: row.siteId,
                type: row.siteType,
                subnet: row.subnet,
                exitNodeId: row.exitNodeId,
                online: row.siteOnline
            }
        });
    });

    // make sure we have at least one resource
    if (resourcesMap.size === 0) {
        return {};
    }

    const config_output: any = {
        http: {
            middlewares: {
                [redirectHttpsMiddlewareName]: {
                    redirectScheme: {
                        scheme: "https"
                    }
                },
                [redirectToRootMiddlewareName]: {
                    redirectRegex: {
                        regex: "^(https?)://([^/]+)(/.*)?",
                        replacement: "${1}://${2}/auth/org",
                        permanent: false
                    }
                }
            }
        }
    };

    // get the key and the resource
    for (const [key, resource] of resourcesMap.entries()) {
        const targets = resource.targets;

        const routerName = `${key}-router`;
        const serviceName = `${key}-service`;
        const fullDomain = `${resource.fullDomain}`;
        const transportName = `${key}-transport`;
        const headersMiddlewareName = `${key}-headers-middleware`;

        if (!resource.enabled) {
            continue;
        }

        if (resource.http) {
            if (!resource.domainId) {
                continue;
            }

            if (!resource.fullDomain) {
                continue;
            }

            if (resource.certificateStatus !== "valid") {
                logger.debug(
                    `Resource ${resource.resourceId} has certificate status ${resource.certificateStatus}`
                );
                continue;
            }

            // add routers and services empty objects if they don't exist
            if (!config_output.http.routers) {
                config_output.http.routers = {};
            }

            if (!config_output.http.services) {
                config_output.http.services = {};
            }

            const domainParts = fullDomain.split(".");
            let wildCard;
            if (domainParts.length <= 2) {
                wildCard = `*.${domainParts.join(".")}`;
            } else {
                wildCard = `*.${domainParts.slice(1).join(".")}`;
            }

            if (!resource.subdomain) {
                wildCard = resource.fullDomain;
            }

            const configDomain = config.getDomain(resource.domainId);

            let certResolver: string, preferWildcardCert: boolean;
            if (!configDomain) {
                certResolver = config.getRawConfig().traefik.cert_resolver;
                preferWildcardCert =
                    config.getRawConfig().traefik.prefer_wildcard_cert;
            } else {
                certResolver = configDomain.cert_resolver;
                preferWildcardCert = configDomain.prefer_wildcard_cert;
            }

            let tls = {};
            if (build == "oss") {
                tls = {
                    certResolver: certResolver,
                    ...(preferWildcardCert
                        ? {
                              domains: [
                                  {
                                      main: wildCard
                                  }
                              ]
                          }
                        : {})
                };
            }

            const additionalMiddlewares =
                config.getRawConfig().traefik.additional_middlewares || [];

            const routerMiddlewares = [
                badgerMiddlewareName,
                ...additionalMiddlewares
            ];

            if (resource.headers || resource.setHostHeader) {
                // if there are headers, parse them into an object
                const headersObj: { [key: string]: string } = {};
                if (resource.headers) {
                    let headersArr: { name: string; value: string }[] = [];
                    try {
                        headersArr = JSON.parse(resource.headers) as {
                            name: string;
                            value: string;
                        }[];
                    } catch (e) {
                        logger.warn(
                            `Failed to parse headers for resource ${resource.resourceId}: ${e}`
                        );
                    }

                    headersArr.forEach((header) => {
                        headersObj[header.name] = header.value;
                    });
                }

                if (resource.setHostHeader) {
                    headersObj["Host"] = resource.setHostHeader;
                }

                // check if the object is not empty
                if (Object.keys(headersObj).length > 0) {
                    // Add the headers middleware
                    if (!config_output.http.middlewares) {
                        config_output.http.middlewares = {};
                    }
                    config_output.http.middlewares[headersMiddlewareName] = {
                        headers: {
                            customRequestHeaders: headersObj
                        }
                    };

                    routerMiddlewares.push(headersMiddlewareName);
                }
            }

            let rule = `Host(\`${fullDomain}\`)`;
            let priority = 100;
            if (resource.path && resource.pathMatchType) {
                priority += 1;
                // add path to rule based on match type
                let path = resource.path;
                // if the path doesn't start with a /, add it
                if (!path.startsWith("/")) {
                    path = `/${path}`;
                }
                if (resource.pathMatchType === "exact") {
                    rule += ` && Path(\`${path}\`)`;
                } else if (resource.pathMatchType === "prefix") {
                    rule += ` && PathPrefix(\`${path}\`)`;
                } else if (resource.pathMatchType === "regex") {
                    rule += ` && PathRegexp(\`${resource.path}\`)`; // this is the raw path because it's a regex
                }
            }

            config_output.http.routers![routerName] = {
                entryPoints: [
                    resource.ssl
                        ? config.getRawConfig().traefik.https_entrypoint
                        : config.getRawConfig().traefik.http_entrypoint
                ],
                middlewares: routerMiddlewares,
                service: serviceName,
                rule: rule,
                priority: priority,
                ...(resource.ssl ? { tls } : {})
            };

            if (resource.ssl) {
                config_output.http.routers![routerName + "-redirect"] = {
                    entryPoints: [
                        config.getRawConfig().traefik.http_entrypoint
                    ],
                    middlewares: [redirectHttpsMiddlewareName],
                    service: serviceName,
                    rule: rule,
                    priority: priority
                };
            }

            config_output.http.services![serviceName] = {
                loadBalancer: {
                    servers: (() => {
                        // Check if any sites are online
                        // THIS IS SO THAT THERE IS SOME IMMEDIATE FEEDBACK
                        // EVEN IF THE SITES HAVE NOT UPDATED YET FROM THE
                        // RECEIVE BANDWIDTH ENDPOINT.

                        // TODO: HOW TO HANDLE ^^^^^^ BETTER
                        const anySitesOnline = (
                            targets as TargetWithSite[]
                        ).some((target: TargetWithSite) => target.site.online);

                        return (
                            (targets as TargetWithSite[])
                                .filter((target: TargetWithSite) => {
                                    if (!target.enabled) {
                                        return false;
                                    }

                                    // If any sites are online, exclude offline sites
                                    if (anySitesOnline && !target.site.online) {
                                        return false;
                                    }

                                    if (
                                        target.site.type === "local" ||
                                        target.site.type === "wireguard"
                                    ) {
                                        if (
                                            !target.ip ||
                                            !target.port ||
                                            !target.method
                                        ) {
                                            return false;
                                        }
                                    } else if (target.site.type === "newt") {
                                        if (
                                            !target.internalPort ||
                                            !target.method ||
                                            !target.site.subnet
                                        ) {
                                            return false;
                                        }
                                    }
                                    return true;
                                })
                                .map((target: TargetWithSite) => {
                                    if (
                                        target.site.type === "local" ||
                                        target.site.type === "wireguard"
                                    ) {
                                        return {
                                            url: `${target.method}://${target.ip}:${target.port}`
                                        };
                                    } else if (target.site.type === "newt") {
                                        const ip =
                                            target.site.subnet!.split("/")[0];
                                        return {
                                            url: `${target.method}://${ip}:${target.internalPort}`
                                        };
                                    }
                                })
                                // filter out duplicates
                                .filter(
                                    (v, i, a) =>
                                        a.findIndex(
                                            (t) => t && v && t.url === v.url
                                        ) === i
                                )
                        );
                    })(),
                    ...(resource.stickySession
                        ? {
                              sticky: {
                                  cookie: {
                                      name: "p_sticky", // TODO: make this configurable via config.yml like other cookies
                                      secure: resource.ssl,
                                      httpOnly: true
                                  }
                              }
                          }
                        : {})
                }
            };

            // Add the serversTransport if TLS server name is provided
            if (resource.tlsServerName) {
                if (!config_output.http.serversTransports) {
                    config_output.http.serversTransports = {};
                }
                config_output.http.serversTransports![transportName] = {
                    serverName: resource.tlsServerName,
                    // unfortunately the following needs to be set. traefik doesn't merge the default serverTransport settings
                    // if defined in the static config and here. if not set, self-signed certs won't work
                    insecureSkipVerify: true
                };
                config_output.http.services![
                    serviceName
                ].loadBalancer.serversTransport = transportName;
            }
        } else {
            // Non-HTTP (TCP/UDP) configuration
            if (!resource.enableProxy) {
                continue;
            }

            const protocol = resource.protocol.toLowerCase();
            const port = resource.proxyPort;

            if (!port) {
                continue;
            }

            if (!config_output[protocol]) {
                config_output[protocol] = {
                    routers: {},
                    services: {}
                };
            }

            config_output[protocol].routers[routerName] = {
                entryPoints: [`${protocol}-${port}`],
                service: serviceName,
                ...(protocol === "tcp" ? { rule: "HostSNI(`*`)" } : {})
            };

            config_output[protocol].services[serviceName] = {
                loadBalancer: {
                    servers: (() => {
                        // Check if any sites are online
                        const anySitesOnline = (
                            targets as TargetWithSite[]
                        ).some((target: TargetWithSite) => target.site.online);

                        return (targets as TargetWithSite[])
                            .filter((target: TargetWithSite) => {
                                if (!target.enabled) {
                                    return false;
                                }

                                // If any sites are online, exclude offline sites
                                if (anySitesOnline && !target.site.online) {
                                    return false;
                                }

                                if (
                                    target.site.type === "local" ||
                                    target.site.type === "wireguard"
                                ) {
                                    if (!target.ip || !target.port) {
                                        return false;
                                    }
                                } else if (target.site.type === "newt") {
                                    if (
                                        !target.internalPort ||
                                        !target.site.subnet
                                    ) {
                                        return false;
                                    }
                                }
                                return true;
                            })
                            .map((target: TargetWithSite) => {
                                if (
                                    target.site.type === "local" ||
                                    target.site.type === "wireguard"
                                ) {
                                    return {
                                        address: `${target.ip}:${target.port}`
                                    };
                                } else if (target.site.type === "newt") {
                                    const ip =
                                        target.site.subnet!.split("/")[0];
                                    return {
                                        address: `${ip}:${target.internalPort}`
                                    };
                                }
                            });
                    })(),
                    ...(resource.stickySession
                        ? {
                              sticky: {
                                  ipStrategy: {
                                      depth: 0,
                                      sourcePort: true
                                  }
                              }
                          }
                        : {})
                }
            };
        }
    }

    if (generateLoginPageRouters) {
        const exitNodeLoginPages = await db
            .select({
                loginPageId: loginPage.loginPageId,
                fullDomain: loginPage.fullDomain,
                exitNodeId: exitNodes.exitNodeId,
                domainId: loginPage.domainId,
                certificateStatus: certificates.status
            })
            .from(loginPage)
            .innerJoin(
                exitNodes,
                eq(exitNodes.exitNodeId, loginPage.exitNodeId)
            )
            .leftJoin(
                certificates,
                eq(certificates.domainId, loginPage.domainId)
            )
            .where(eq(exitNodes.exitNodeId, exitNodeId));

        if (exitNodeLoginPages.length > 0) {
            if (!config_output.http.services) {
                config_output.http.services = {};
            }

            if (!config_output.http.services["landing-service"]) {
                config_output.http.services["landing-service"] = {
                    loadBalancer: {
                        servers: [
                            {
                                url: `http://${
                                    config.getRawConfig().server
                                        .internal_hostname
                                }:${config.getRawConfig().server.next_port}`
                            }
                        ]
                    }
                };
            }

            for (const lp of exitNodeLoginPages) {
                if (!lp.domainId) {
                    continue;
                }

                if (!lp.fullDomain) {
                    continue;
                }

                if (lp.certificateStatus !== "valid") {
                    continue;
                }

                // auth-allowed:
                //   rule: "Host(`auth.pangolin.internal`) && (PathRegexp(`^/auth/resource/[0-9]+$`) || PathPrefix(`/_next`))"
                //   service: next-service
                //   entryPoints:
                //     - websecure

                const routerName = `loginpage-${lp.loginPageId}`;
                const fullDomain = `${lp.fullDomain}`;

                if (!config_output.http.routers) {
                    config_output.http.routers = {};
                }

                config_output.http.routers![routerName + "-router"] = {
                    entryPoints: [
                        config.getRawConfig().traefik.https_entrypoint
                    ],
                    service: "landing-service",
                    rule: `Host(\`${fullDomain}\`) && (PathRegexp(\`^/auth/resource/[^/]+$\`) || PathRegexp(\`^/auth/idp/[0-9]+/oidc/callback\`) || PathPrefix(\`/_next\`) || Path(\`/auth/org\`) || PathRegexp(\`^/__nextjs*\`))`,
                    priority: 203,
                    tls: {}
                };

                // auth-catchall:
                //   rule: "Host(`auth.example.com`)"
                //   middlewares:
                //     - redirect-to-root
                //   service: next-service
                //   entryPoints:
                //     - web

                config_output.http.routers![routerName + "-catchall"] = {
                    entryPoints: [
                        config.getRawConfig().traefik.https_entrypoint
                    ],
                    middlewares: [redirectToRootMiddlewareName],
                    service: "landing-service",
                    rule: `Host(\`${fullDomain}\`)`,
                    priority: 202,
                    tls: {}
                };

                // we need to add a redirect from http to https too
                config_output.http.routers![routerName + "-redirect"] = {
                    entryPoints: [
                        config.getRawConfig().traefik.http_entrypoint
                    ],
                    middlewares: [redirectHttpsMiddlewareName],
                    service: "landing-service",
                    rule: `Host(\`${fullDomain}\`)`,
                    priority: 201
                };
            }
        }
    }

    return config_output;
}

function sanitizePath(path: string | null | undefined): string | undefined {
    if (!path) return undefined;
    // clean any non alphanumeric characters from the path and replace with dashes
    // the path cant be too long either, so limit to 50 characters
    if (path.length > 50) {
        path = path.substring(0, 50);
    }
    return path.replace(/[^a-zA-Z0-9]/g, "");
}
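For a single enabled HTTPS resource, the dynamic config built above reduces to a structure like the following (illustrative values; the "42-" keys follow the `${key}-router` / `${key}-service` convention, and the entry point name depends on the traefik config):

// Illustrative only: roughly what getTraefikConfig emits for one resource keyed "42".
const exampleOutput = {
    http: {
        middlewares: {
            "redirect-to-https": { redirectScheme: { scheme: "https" } }
        },
        routers: {
            "42-router": {
                entryPoints: ["websecure"], // assumed https_entrypoint value
                middlewares: ["badger"],
                service: "42-service",
                rule: "Host(`app.example.com`)",
                priority: 100,
                tls: {}
            }
        },
        services: {
            "42-service": {
                loadBalancer: { servers: [{ url: "http://100.89.0.2:8080" }] }
            }
        }
    }
};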
@@ -1,5 +1,5 @@
import { assertEquals } from "@test/assert";
import { isDomainCoveredByWildcard } from "./traefikConfig";
import { isDomainCoveredByWildcard } from "./TraefikConfigManager";

function runTests() {
    console.log('Running wildcard domain coverage tests...');
@@ -163,6 +163,26 @@ export function validateHeaders(headers: string): boolean {
    });
}

export function isSecondLevelDomain(domain: string): boolean {
    if (!domain || typeof domain !== 'string') {
        return false;
    }

    const trimmedDomain = domain.trim().toLowerCase();

    // Split into parts
    const parts = trimmedDomain.split('.');

    // Should have exactly 2 parts for a second-level domain (e.g., "example.com")
    if (parts.length !== 2) {
        return false;
    }

    // Check if the TLD part is valid
    const tld = parts[1].toUpperCase();
    return validTlds.includes(tld);
}

const validTlds = [
    "AAA",
    "AARP",
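A quick usage sketch (the TLD list is elided above; the results assume `validTlds` contains the usual IANA entries such as "COM"):

// Exactly two labels plus a recognized TLD qualifies:
isSecondLevelDomain("example.com");     // true (assuming "COM" is in validTlds)
isSecondLevelDomain("api.example.com"); // false - three labels
isSecondLevelDomain("example");         // false - no TLD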
@@ -27,4 +27,4 @@ export * from "./verifyApiKeyAccess";
export * from "./verifyDomainAccess";
export * from "./verifyClientsEnabled";
export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess";
export * from "./verifySiteResourceAccess";

server/middlewares/private/corsWithLoginPage.ts (new file, 98 lines)
@@ -0,0 +1,98 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Request, Response, NextFunction } from "express";
import cors, { CorsOptions } from "cors";
import config from "@server/lib/config";
import logger from "@server/logger";
import { db, loginPage } from "@server/db";
import { eq } from "drizzle-orm";

async function isValidLoginPageDomain(host: string): Promise<boolean> {
    try {
        const [result] = await db
            .select()
            .from(loginPage)
            .where(eq(loginPage.fullDomain, host))
            .limit(1);

        const isValid = !!result;

        return isValid;
    } catch (error) {
        logger.error("Error checking loginPage domain:", error);
        return false;
    }
}

export function corsWithLoginPageSupport(corsConfig: any) {
    const options = {
        ...(corsConfig?.origins
            ? { origin: corsConfig.origins }
            : {
                  origin: (origin: any, callback: any) => {
                      callback(null, true);
                  }
              }),
        ...(corsConfig?.methods && { methods: corsConfig.methods }),
        ...(corsConfig?.allowed_headers && {
            allowedHeaders: corsConfig.allowed_headers
        }),
        credentials: !(corsConfig?.credentials === false)
    };

    return async (req: Request, res: Response, next: NextFunction) => {
        const originValidatedCorsConfig = {
            origin: async (
                origin: string | undefined,
                callback: (err: Error | null, allow?: boolean) => void
            ) => {
                // If no origin (e.g., same-origin request), allow it
                if (!origin) {
                    return callback(null, true);
                }

                const dashboardUrl = config.getRawConfig().app.dashboard_url;

                // If no dashboard_url is configured, allow all origins
                if (!dashboardUrl) {
                    return callback(null, true);
                }

                // Check if origin matches dashboard URL
                const dashboardHost = new URL(dashboardUrl).host;
                const originHost = new URL(origin).host;

                if (originHost === dashboardHost) {
                    return callback(null, true);
                }

                // If origin doesn't match dashboard URL, check if it's a valid loginPage domain
                const isValidDomain = await isValidLoginPageDomain(originHost);

                if (isValidDomain) {
                    return callback(null, true);
                }

                // Origin is not valid
                return callback(null, false);
            },
            methods: corsConfig?.methods,
            allowedHeaders: corsConfig?.allowed_headers,
            credentials: corsConfig?.credentials !== false
        } as CorsOptions;

        return cors(originValidatedCorsConfig)(req, res, next);
    };
}
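Because the origin callback above is async and hits the database on each cross-origin preflight, one might cache the lookup in front of it. A hypothetical wrapper, not part of the diff (the TTL is illustrative):

// Hypothetical memoization wrapper for a (host) => Promise<boolean> check.
function cachedDomainCheck(
    check: (host: string) => Promise<boolean>,
    ttlMs = 60_000
) {
    const hits = new Map<string, { ok: boolean; at: number }>();
    return async (host: string): Promise<boolean> => {
        const entry = hits.get(host);
        if (entry && Date.now() - entry.at < ttlMs) return entry.ok;
        const ok = await check(host);
        hits.set(host, { ok, at: Date.now() });
        return ok;
    };
}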
server/middlewares/private/index.ts (new file, 18 lines)
@@ -0,0 +1,18 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

export * from "./verifyCertificateAccess";
export * from "./verifyRemoteExitNodeAccess";
export * from "./verifyIdpAccess";
export * from "./verifyLoginPageAccess";
export * from "./corsWithLoginPage";

server/middlewares/private/verifyCertificateAccess.ts (new file, 126 lines)
@@ -0,0 +1,126 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Request, Response, NextFunction } from "express";
import { db, domainNamespaces } from "@server/db";
import { certificates } from "@server/db";
import { domains, orgDomains } from "@server/db";
import { eq, and } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";

export async function verifyCertificateAccess(
    req: Request,
    res: Response,
    next: NextFunction
) {
    try {
        // Assume user/org access is already verified
        const orgId = req.params.orgId;
        const certId = req.params.certId || req.body?.certId || req.query?.certId;
        let domainId =
            req.params.domainId || req.body?.domainId || req.query?.domainId;

        if (!orgId) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID")
            );
        }

        if (!domainId) {
            if (!certId) {
                return next(
                    createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
                );
            }

            // Get the certificate and its domainId
            const [cert] = await db
                .select()
                .from(certificates)
                .where(eq(certificates.certId, Number(certId)))
                .limit(1);

            if (!cert) {
                return next(
                    createHttpError(
                        HttpCode.NOT_FOUND,
                        `Certificate with ID ${certId} not found`
                    )
                );
            }

            domainId = cert.domainId;
            if (!domainId) {
                return next(
                    createHttpError(
                        HttpCode.NOT_FOUND,
                        `Certificate with ID ${certId} does not have a domain`
                    )
                );
            }
        }

        if (!domainId) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Must provide either certId or domainId")
            );
        }

        // Check if the domain is a namespace domain
        const [namespaceDomain] = await db
            .select()
            .from(domainNamespaces)
            .where(eq(domainNamespaces.domainId, domainId))
            .limit(1);

        if (namespaceDomain) {
            // If it's a namespace domain, we can skip the org check
            return next();
        }

        // Check if the domain is associated with the org
        const [orgDomain] = await db
            .select()
            .from(orgDomains)
            .where(
                and(
                    eq(orgDomains.orgId, orgId),
                    eq(orgDomains.domainId, domainId)
                )
            )
            .limit(1);

        if (!orgDomain) {
            return next(
                createHttpError(
                    HttpCode.FORBIDDEN,
                    "Organization does not have access to this certificate"
                )
            );
        }

        return next();
    } catch (error) {
        logger.error(error);
        return next(
            createHttpError(
                HttpCode.INTERNAL_SERVER_ERROR,
                "Error verifying certificate access"
            )
        );
    }
}
export default verifyCertificateAccess;
server/middlewares/private/verifyIdpAccess.ts (new file, 102 lines)
@@ -0,0 +1,102 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Request, Response, NextFunction } from "express";
import { userOrgs, db, idp, idpOrg } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";

export async function verifyIdpAccess(
    req: Request,
    res: Response,
    next: NextFunction
) {
    try {
        const userId = req.user!.userId;
        const idpId =
            req.params.idpId || req.body.idpId || req.query.idpId;
        const orgId = req.params.orgId;

        if (!userId) {
            return next(
                createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
            );
        }

        if (!orgId) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID")
            );
        }

        if (!idpId) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid key ID")
            );
        }

        const [idpRes] = await db
            .select()
            .from(idp)
            .innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
            .where(
                and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId))
            )
            .limit(1);

        if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {
            return next(
                createHttpError(
                    HttpCode.NOT_FOUND,
                    `IdP with ID ${idpId} not found for organization ${orgId}`
                )
            );
        }

        if (!req.userOrg) {
            const userOrgRole = await db
                .select()
                .from(userOrgs)
                .where(
                    and(
                        eq(userOrgs.userId, userId),
                        eq(userOrgs.orgId, idpRes.idpOrg.orgId)
                    )
                )
                .limit(1);
            req.userOrg = userOrgRole[0];
        }

        if (!req.userOrg) {
            return next(
                createHttpError(
                    HttpCode.FORBIDDEN,
                    "User does not have access to this organization"
                )
            );
        }

        const userOrgRoleId = req.userOrg.roleId;
        req.userOrgRoleId = userOrgRoleId;

        return next();
    } catch (error) {
        return next(
            createHttpError(
                HttpCode.INTERNAL_SERVER_ERROR,
                "Error verifying idp access"
            )
        );
    }
}

server/middlewares/private/verifyLoginPageAccess.ts (new file, 81 lines)
@@ -0,0 +1,81 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Request, Response, NextFunction } from "express";
import { userOrgs, db, loginPageOrg } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";

export async function verifyLoginPageAccess(
    req: Request,
    res: Response,
    next: NextFunction
) {
    try {
        const userId = req.user!.userId;
        const loginPageId =
            req.params.loginPageId ||
            req.body.loginPageId ||
            req.query.loginPageId;

        if (!userId) {
            return next(
                createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
            );
        }

        if (!loginPageId) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid login page ID")
            );
        }

        const loginPageOrgs = await db
            .select({
                orgId: loginPageOrg.orgId
            })
            .from(loginPageOrg)
            .where(eq(loginPageOrg.loginPageId, loginPageId));

        const orgIds = loginPageOrgs.map((lpo) => lpo.orgId);

        const existingUserOrgs = await db
            .select()
            .from(userOrgs)
            .where(
                and(
                    eq(userOrgs.userId, userId),
                    inArray(userOrgs.orgId, orgIds)
                )
            );

        if (existingUserOrgs.length === 0) {
            return next(
                createHttpError(
                    HttpCode.NOT_FOUND,
                    `Login page with ID ${loginPageId} not found for user's organizations`
                )
            );
        }

        return next();
    } catch (error) {
        return next(
            createHttpError(
                HttpCode.INTERNAL_SERVER_ERROR,
                "Error verifying login page access"
            )
        );
    }
}

server/middlewares/private/verifyRemoteExitNode.ts (new file, 56 lines)
@@ -0,0 +1,56 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { NextFunction, Response } from "express";
import ErrorResponse from "@server/types/ErrorResponse";
import config from "@server/lib/config";
import { unauthorized } from "@server/auth/unauthorizedResponse";
import logger from "@server/logger";
import { validateRemoteExitNodeSessionToken } from "@server/auth/sessions/privateRemoteExitNode";

export const verifySessionRemoteExitNodeMiddleware = async (
    req: any,
    res: Response<ErrorResponse>,
    next: NextFunction
) => {
    // get the token from the auth header
    const token = req.headers["authorization"]?.split(" ")[1] || "";

    const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);

    if (!session || !remoteExitNode) {
        if (config.getRawConfig().app.log_failed_attempts) {
            logger.info(`Remote exit node session not found. IP: ${req.ip}.`);
        }
        return next(unauthorized());
    }

    // const existingUser = await db
    //     .select()
    //     .from(users)
    //     .where(eq(users.userId, user.userId));

    // if (!existingUser || !existingUser[0]) {
    //     if (config.getRawConfig().app.log_failed_attempts) {
    //         logger.info(`User session not found. IP: ${req.ip}.`);
    //     }
    //     return next(
    //         createHttpError(HttpCode.BAD_REQUEST, "User does not exist")
    //     );
    // }

    req.session = session;
    req.remoteExitNode = remoteExitNode;

    next();
};

server/middlewares/private/verifyRemoteExitNodeAccess.ts (new file, 118 lines)
@@ -0,0 +1,118 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { Request, Response, NextFunction } from "express";
import { db, exitNodeOrgs, exitNodes, remoteExitNodes } from "@server/db";
import { sites, userOrgs, userSites, roleSites, roles } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";

export async function verifyRemoteExitNodeAccess(
    req: Request,
    res: Response,
    next: NextFunction
) {
    const userId = req.user!.userId; // Assuming you have user information in the request
    const orgId = req.params.orgId;
    const remoteExitNodeId =
        req.params.remoteExitNodeId ||
        req.body.remoteExitNodeId ||
        req.query.remoteExitNodeId;

    if (!userId) {
        return next(
            createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
        );
    }

    try {
        const [remoteExitNode] = await db
            .select()
            .from(remoteExitNodes)
            .where(and(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)));

        if (!remoteExitNode) {
            return next(
                createHttpError(
                    HttpCode.NOT_FOUND,
                    `Remote exit node with ID ${remoteExitNodeId} not found`
                )
            );
        }

        if (!remoteExitNode.exitNodeId) {
            return next(
                createHttpError(
                    HttpCode.INTERNAL_SERVER_ERROR,
                    `Remote exit node with ID ${remoteExitNodeId} does not have an exit node ID`
                )
            );
        }

        const [exitNodeOrg] = await db
            .select()
            .from(exitNodeOrgs)
            .where(
                and(
                    eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId),
                    eq(exitNodeOrgs.orgId, orgId)
                )
            );

        if (!exitNodeOrg) {
            return next(
                createHttpError(
                    HttpCode.NOT_FOUND,
                    `Remote exit node with ID ${remoteExitNodeId} not found in organization ${orgId}`
                )
            );
        }

        if (!req.userOrg) {
            // Get user's role ID in the organization
            const userOrgRole = await db
                .select()
                .from(userOrgs)
                .where(
                    and(
                        eq(userOrgs.userId, userId),
                        eq(userOrgs.orgId, exitNodeOrg.orgId)
                    )
                )
                .limit(1);
            req.userOrg = userOrgRole[0];
        }

        if (!req.userOrg) {
            return next(
                createHttpError(
                    HttpCode.FORBIDDEN,
                    "User does not have access to this organization"
                )
            );
        }

        const userOrgRoleId = req.userOrg.roleId;
        req.userOrgRoleId = userOrgRoleId;

        return next();
    } catch (error) {
        return next(
            createHttpError(
                HttpCode.INTERNAL_SERVER_ERROR,
                "Error verifying remote exit node access"
            )
        );
    }
}
@@ -37,7 +37,7 @@ export async function verifyOrgAccess(
    }

    if (!req.userOrg) {
        next(
        return next(
            createHttpError(
                HttpCode.FORBIDDEN,
                "User does not have access to this organization"

@@ -6,7 +6,7 @@ import { z } from "zod";
import { db } from "@server/db";
import { User, users } from "@server/db";
import { eq } from "drizzle-orm";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import {
    hashPassword,
    verifyPassword

@@ -3,7 +3,7 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { validateResourceSessionToken } from "@server/auth/sessions/resource";
import logger from "@server/logger";

@@ -6,7 +6,7 @@ import { z } from "zod";
import { db } from "@server/db";
import { User, users } from "@server/db";
import { eq } from "drizzle-orm";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { verifyPassword } from "@server/auth/password";
import { verifyTotpCode } from "@server/auth/totp";
import logger from "@server/logger";

@@ -10,7 +10,10 @@ export * from "./resetPassword";
export * from "./requestPasswordReset";
export * from "./setServerAdmin";
export * from "./initialSetupComplete";
export * from "./privateQuickStart";
export * from "./validateSetupToken";
export * from "./changePassword";
export * from "./checkResourceSession";
export * from "./securityKey";
export * from "./privateGetSessionTransferToken";
export * from "./privateTransferSession";

@@ -2,7 +2,7 @@ import { NextFunction, Request, Response } from "express";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db, users } from "@server/db";
import { eq } from "drizzle-orm";
97
server/routers/auth/privateGetSessionTransferToken.ts
Normal file
97
server/routers/auth/privateGetSessionTransferToken.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, sessionTransferToken } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import {
|
||||
generateSessionToken,
|
||||
SESSION_COOKIE_NAME
|
||||
} from "@server/auth/sessions/app";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { response } from "@server/lib/response";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const paramsSchema = z.object({}).strict();
|
||||
|
||||
export type GetSessionTransferTokenRenponse = {
|
||||
token: string;
|
||||
};
|
||||
|
||||
export async function getSessionTransferToken(
    req: Request,
    res: Response,
    next: NextFunction
): Promise<any> {
    try {
        const parsedParams = paramsSchema.safeParse(req.params);
        if (!parsedParams.success) {
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    fromError(parsedParams.error).toString()
                )
            );
        }

        const { user, session } = req;

        if (!user || !session) {
            return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized"));
        }

        const tokenRaw = generateSessionToken();
        const token = encodeHexLowerCase(
            sha256(new TextEncoder().encode(tokenRaw))
        );
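        // Note: only this SHA-256 hash is persisted below; the raw token is
        // returned once to the caller and cannot be recovered from the database.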

        const rawSessionId = req.cookies[SESSION_COOKIE_NAME];

        if (!rawSessionId) {
            return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized"));
        }

        const encryptedSession = encrypt(
            rawSessionId,
            config.getRawConfig().server.secret!
        );

        await db.insert(sessionTransferToken).values({
            encryptedSession,
            token,
            sessionId: session.sessionId,
            expiresAt: Date.now() + 30 * 1000 // Token valid for 30 seconds
        });

        return response<GetSessionTransferTokenResponse>(res, {
            data: {
                token: tokenRaw
            },
            success: true,
            error: false,
            message: "Transfer token created successfully",
            status: HttpCode.OK
        });
    } catch (error) {
        logger.error(error);
        return next(
            createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
        );
    }
}

581
server/routers/auth/privateQuickStart.ts
Normal file
@@ -0,0 +1,581 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import { NextFunction, Request, Response } from "express";
import {
    account,
    db,
    domainNamespaces,
    domains,
    exitNodes,
    newts,
    newtSessions,
    orgs,
    passwordResetTokens,
    Resource,
    resourcePassword,
    resourcePincode,
    resources,
    resourceWhitelist,
    roleResources,
    roles,
    roleSites,
    sites,
    targetHealthCheck,
    targets,
    userResources,
    userSites
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { users } from "@server/db";
import { fromError } from "zod-validation-error";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import { eq, and, sql } from "drizzle-orm";
import moment from "moment";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import logger from "@server/logger";
import { hashPassword } from "@server/auth/password";
import { UserType } from "@server/types/UserTypes";
import { createUserAccountOrg } from "@server/lib/private/createUserAccountOrg";
import { sendEmail } from "@server/emails";
import WelcomeQuickStart from "@server/emails/templates/WelcomeQuickStart";
import { alphabet, generateRandomString } from "oslo/crypto";
import { createDate, TimeSpan } from "oslo";
import { getUniqueResourceName, getUniqueSiteName } from "@server/db/names";
import { pickPort } from "../target/helpers";
import { addTargets } from "../newt/targets";
import { isTargetValid } from "@server/lib/validators";
import { listExitNodes } from "@server/lib/exitNodes";

const bodySchema = z.object({
    email: z.string().toLowerCase().email(),
    ip: z.string().refine(isTargetValid),
    method: z.enum(["http", "https"]),
    port: z.number().int().min(1).max(65535),
    pincode: z
        .string()
        .regex(/^\d{6}$/)
        .optional(),
    password: z.string().min(4).max(100).optional(),
    enableWhitelist: z.boolean().optional().default(true),
    animalId: z.string() // This is actually the secret key for the backend
});

export type QuickStartBody = z.infer<typeof bodySchema>;

export type QuickStartResponse = {
    newtId: string;
    newtSecret: string;
    resourceUrl: string;
    completeSignUpLink: string;
};

const DEMO_UBO_KEY = "b460293f-347c-4b30-837d-4e06a04d5a22";

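// Illustrative only, not part of this commit: a request body that satisfies
// bodySchema above. The values are made up; pincode and password are
// optional and omitted here.
//
// const exampleBody: QuickStartBody = {
//     email: "user@example.com",
//     ip: "10.0.0.5",       // must pass isTargetValid
//     method: "http",
//     port: 8080,
//     enableWhitelist: true,
//     animalId: "<obfuscated key token>" // produced by the frontend; see validateTokenOnApi
// };
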
export async function quickStart(
    req: Request,
    res: Response,
    next: NextFunction
): Promise<any> {
    const parsedBody = bodySchema.safeParse(req.body);

    if (!parsedBody.success) {
        return next(
            createHttpError(
                HttpCode.BAD_REQUEST,
                fromError(parsedBody.error).toString()
            )
        );
    }

    const {
        email,
        ip,
        method,
        port,
        pincode,
        password,
        enableWhitelist,
        animalId
    } = parsedBody.data;

    try {
        const tokenValidation = validateTokenOnApi(animalId);

        if (!tokenValidation.isValid) {
            logger.warn(
                `Quick start failed for ${email} token ${animalId}: ${tokenValidation.message}`
            );
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    "Invalid or expired token"
                )
            );
        }

        if (animalId === DEMO_UBO_KEY) {
            if (email !== "mehrdad@getubo.com") {
                return next(
                    createHttpError(
                        HttpCode.BAD_REQUEST,
                        "Invalid email for demo Ubo key"
                    )
                );
            }

            const [existing] = await db
                .select()
                .from(users)
                .where(
                    and(
                        eq(users.email, email),
                        eq(users.type, UserType.Internal)
                    )
                );

            if (existing) {
                // delete the user if it already exists
                await db.delete(users).where(eq(users.userId, existing.userId));
                const orgId = `org_${existing.userId}`;
                await db.delete(orgs).where(eq(orgs.orgId, orgId));
            }
        }

        const tempPassword = generateId(15);
        const passwordHash = await hashPassword(tempPassword);
        const userId = generateId(15);

        // TODO: see if that user already exists?

        // Create the sandbox user
        const existing = await db
            .select()
            .from(users)
            .where(
                and(eq(users.email, email), eq(users.type, UserType.Internal))
            );

        if (existing && existing.length > 0) {
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    "A user with that email address already exists"
                )
            );
        }

        let newtId: string;
        let secret: string;
        let fullDomain: string;
        let resource: Resource;
        let orgId: string;
        let completeSignUpLink: string;

        await db.transaction(async (trx) => {
            await trx.insert(users).values({
                userId: userId,
                type: UserType.Internal,
                username: email,
                email: email,
                passwordHash,
                dateCreated: moment().toISOString()
            });

            // create user's account
            await trx.insert(account).values({
                userId
            });
        });

        const { success, error, org } = await createUserAccountOrg(
            userId,
            email
        );
        if (!success) {
            if (error) {
                throw new Error(error);
            }
            throw new Error("Failed to create user account and organization");
        }
        if (!org) {
            throw new Error("Failed to create user account and organization");
        }

        orgId = org.orgId;

        await db.transaction(async (trx) => {
            const token = generateRandomString(
                8,
                alphabet("0-9", "A-Z", "a-z")
            );

            await trx
                .delete(passwordResetTokens)
                .where(eq(passwordResetTokens.userId, userId));

            const tokenHash = await hashPassword(token);

            await trx.insert(passwordResetTokens).values({
                userId: userId,
                email: email,
                tokenHash,
                expiresAt: createDate(new TimeSpan(7, "d")).getTime()
            });

            // // Create the sandbox newt
            // const newClientAddress = await getNextAvailableClientSubnet(orgId);
            // if (!newClientAddress) {
            //     throw new Error("No available subnet found");
            // }

            // const clientAddress = newClientAddress.split("/")[0];

            newtId = generateId(15);
            secret = generateId(48);

            // Create the sandbox site
            const siteNiceId = await getUniqueSiteName(orgId);
            const siteName = `First Site`;
            let siteId: number | undefined;

            // pick a random exit node
            const exitNodesList = await listExitNodes(orgId);

            // select a random exit node
            const randomExitNode =
                exitNodesList[Math.floor(Math.random() * exitNodesList.length)];

            if (!randomExitNode) {
                throw new Error("No exit nodes available");
            }

            const [newSite] = await trx
                .insert(sites)
                .values({
                    orgId,
                    exitNodeId: randomExitNode.exitNodeId,
                    name: siteName,
                    niceId: siteNiceId,
                    // address: clientAddress,
                    type: "newt",
                    dockerSocketEnabled: true
                })
                .returning();

            siteId = newSite.siteId;

            const adminRole = await trx
                .select()
                .from(roles)
                .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
                .limit(1);

            if (adminRole.length === 0) {
                throw new Error("Admin role not found");
            }

            await trx.insert(roleSites).values({
                roleId: adminRole[0].roleId,
                siteId: newSite.siteId
            });

            if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
                // make sure the user can access the site
                await trx.insert(userSites).values({
                    userId: req.user?.userId!,
                    siteId: newSite.siteId
                });
            }

            // add the peer to the exit node
            const secretHash = await hashPassword(secret!);

            await trx.insert(newts).values({
                newtId: newtId!,
                secretHash,
                siteId: newSite.siteId,
                dateCreated: moment().toISOString()
            });

            const [randomNamespace] = await trx
                .select()
                .from(domainNamespaces)
                .orderBy(sql`RANDOM()`)
                .limit(1);

            if (!randomNamespace) {
                throw new Error("No domain namespace available");
            }

            const [randomNamespaceDomain] = await trx
                .select()
                .from(domains)
                .where(eq(domains.domainId, randomNamespace.domainId))
                .limit(1);

            if (!randomNamespaceDomain) {
                throw new Error("No domain found for the namespace");
            }

            const resourceNiceId = await getUniqueResourceName(orgId);

            // Create sandbox resource
            const subdomain = `${resourceNiceId}-${generateId(5)}`;
            fullDomain = `${subdomain}.${randomNamespaceDomain.baseDomain}`;

            const resourceName = `First Resource`;

            const newResource = await trx
                .insert(resources)
                .values({
                    niceId: resourceNiceId,
                    fullDomain,
                    domainId: randomNamespaceDomain.domainId,
                    orgId,
                    name: resourceName,
                    subdomain,
                    http: true,
                    protocol: "tcp",
                    ssl: true,
                    sso: false,
                    emailWhitelistEnabled: enableWhitelist
                })
                .returning();

            await trx.insert(roleResources).values({
                roleId: adminRole[0].roleId,
                resourceId: newResource[0].resourceId
            });

            if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
                // make sure the user can access the resource
                await trx.insert(userResources).values({
                    userId: req.user?.userId!,
                    resourceId: newResource[0].resourceId
                });
            }

            resource = newResource[0];

            // Create the sandbox target
            const { internalPort, targetIps } = await pickPort(siteId!, trx);

            if (!internalPort) {
                throw new Error("No available internal port");
            }

            const newTarget = await trx
                .insert(targets)
                .values({
                    resourceId: resource.resourceId,
                    siteId: siteId!,
                    internalPort,
                    ip,
                    method,
                    port,
                    enabled: true
                })
                .returning();

            const newHealthcheck = await trx
                .insert(targetHealthCheck)
                .values({
                    targetId: newTarget[0].targetId,
                    hcEnabled: false
                })
                .returning();

            // add the new target to the targetIps array
            targetIps.push(`${ip}/32`);

            const [newt] = await trx
                .select()
                .from(newts)
                .where(eq(newts.siteId, siteId!))
                .limit(1);

            await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);

            // Set resource pincode if provided
            if (pincode) {
                await trx
                    .delete(resourcePincode)
                    .where(
                        eq(resourcePincode.resourceId, resource!.resourceId)
                    );

                const pincodeHash = await hashPassword(pincode);

                await trx.insert(resourcePincode).values({
                    resourceId: resource!.resourceId,
                    pincodeHash,
                    digitLength: 6
                });
            }

            // Set resource password if provided
            if (password) {
                await trx
                    .delete(resourcePassword)
                    .where(
                        eq(resourcePassword.resourceId, resource!.resourceId)
                    );

                const passwordHash = await hashPassword(password);

                await trx.insert(resourcePassword).values({
                    resourceId: resource!.resourceId,
                    passwordHash
                });
            }

            // Set resource OTP if whitelist is enabled
            if (enableWhitelist) {
                await trx.insert(resourceWhitelist).values({
                    email,
                    resourceId: resource!.resourceId
                });
            }

            completeSignUpLink = `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}&token=${token}`;

            // Store token for email outside transaction
            await sendEmail(
                WelcomeQuickStart({
                    username: email,
                    link: completeSignUpLink,
                    fallbackLink: `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}`,
                    resourceMethod: method,
                    resourceHostname: ip,
                    resourcePort: port,
                    resourceUrl: `https://${fullDomain}`,
                    cliCommand: `newt --id ${newtId} --secret ${secret}`
                }),
                {
                    to: email,
                    from: config.getNoReplyEmail(),
                    subject: `Access your Pangolin dashboard and resources`
                }
            );
        });

        return response<QuickStartResponse>(res, {
            data: {
                newtId: newtId!,
                newtSecret: secret!,
                resourceUrl: `https://${fullDomain!}`,
                completeSignUpLink: completeSignUpLink!
            },
            success: true,
            error: false,
            message: "Quick start completed successfully",
            status: HttpCode.OK
        });
    } catch (e) {
        if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
            if (config.getRawConfig().app.log_failed_attempts) {
                logger.info(
                    `Account already exists with that email. Email: ${email}. IP: ${req.ip}.`
                );
            }
            return next(
                createHttpError(
                    HttpCode.BAD_REQUEST,
                    "A user with that email address already exists"
                )
            );
        } else {
            logger.error(e);
            return next(
                createHttpError(
                    HttpCode.INTERNAL_SERVER_ERROR,
                    "Failed to do quick start"
                )
            );
        }
    }
}

const BACKEND_SECRET_KEY = "4f9b6000-5d1a-11f0-9de7-ff2cc032f501";

/**
 * Validates a token received from the frontend.
 * @param {string} token The validation token from the request.
 * @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid.
 */
const validateTokenOnApi = (
    token: string
): { isValid: boolean; message: string } => {
    if (token === DEMO_UBO_KEY) {
        // Special case for demo UBO key
        return { isValid: true, message: "Demo UBO key is valid." };
    }

    if (!token) {
        return { isValid: false, message: "Error: No token provided." };
    }

    try {
        // 1. Decode the base64 string
        const decodedB64 = atob(token);

        // 2. Reverse the character code manipulation
        const deobfuscated = decodedB64
            .split("")
            .map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift
            .join("");

        // 3. Split the data to get the original secret and timestamp
        const parts = deobfuscated.split("|");
        if (parts.length !== 2) {
            throw new Error("Invalid token format.");
        }
        const receivedKey = parts[0];
        const tokenTimestamp = parseInt(parts[1], 10);

        // 4. Check if the secret key matches
        if (receivedKey !== BACKEND_SECRET_KEY) {
            return { isValid: false, message: "Invalid token: Key mismatch." };
        }

        // 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks
        const now = Date.now();
        const timeDifference = now - tokenTimestamp;

        if (timeDifference > 30000) {
            // 30 seconds
            return { isValid: false, message: "Invalid token: Expired." };
        }

        if (timeDifference < 0) {
            // Timestamp is in the future
            return {
                isValid: false,
                message: "Invalid token: Timestamp is in the future."
            };
        }

        // If all checks pass, the token is valid
        return { isValid: true, message: "Token is valid!" };
    } catch (error) {
        // This will catch errors from atob (if not valid base64) or other issues.
        return {
            isValid: false,
            message: `Error: ${(error as Error).message}`
        };
    }
};
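
The frontend code that produces `animalId` is not included in this diff. Reversing the decode steps above, its counterpart would presumably look something like the following sketch (the function name is made up):

// Hypothetical client-side counterpart to validateTokenOnApi: join the
// shared key with the current timestamp, shift every character code up
// by 5, then base64-encode the result.
function makeQuickStartToken(secretKey: string): string {
    const payload = `${secretKey}|${Date.now()}`;
    const obfuscated = payload
        .split("")
        .map((char) => String.fromCharCode(char.charCodeAt(0) + 5))
        .join("");
    return btoa(obfuscated);
}
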
128
server/routers/auth/privateTransferSession.ts
Normal file
@@ -0,0 +1,128 @@
/*
 * This file is part of a proprietary work.
 *
 * Copyright (c) 2025 Fossorial, Inc.
 * All rights reserved.
 *
 * This file is licensed under the Fossorial Commercial License.
 * You may not use this file except in compliance with the License.
 * Unauthorized use, copying, modification, or distribution is strictly prohibited.
 *
 * This file is not licensed under the AGPLv3.
 */

import HttpCode from "@server/types/HttpCode";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { sessions, sessionTransferToken } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import { response } from "@server/lib/response";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { serializeSessionCookie } from "@server/auth/sessions/app";
import { decrypt } from "@server/lib/crypto";
import config from "@server/lib/config";

const bodySchema = z.object({
    token: z.string()
});

export type TransferSessionBodySchema = z.infer<typeof bodySchema>;

export type TransferSessionResponse = {
    valid: boolean;
    cookie?: string;
};

export async function transferSession(
    req: Request,
    res: Response,
    next: NextFunction
): Promise<any> {
    const parsedBody = bodySchema.safeParse(req.body);

    if (!parsedBody.success) {
        return next(
            createHttpError(
                HttpCode.BAD_REQUEST,
                fromError(parsedBody.error).toString()
            )
        );
    }

    try {
        const { token } = parsedBody.data;

        const tokenRaw = encodeHexLowerCase(
            sha256(new TextEncoder().encode(token))
        );

        const [existing] = await db
            .select()
            .from(sessionTransferToken)
            .where(eq(sessionTransferToken.token, tokenRaw))
            .innerJoin(
                sessions,
                eq(sessions.sessionId, sessionTransferToken.sessionId)
            )
            .limit(1);

        if (!existing) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid transfer token")
            );
        }

        const transferToken = existing.sessionTransferToken;
        const session = existing.session;

        if (!transferToken) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Invalid transfer token")
            );
        }

        await db
            .delete(sessionTransferToken)
            .where(eq(sessionTransferToken.token, tokenRaw));
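        // Single-use: the token row is deleted before the expiry check, so
        // even a failed (expired) attempt consumes it.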

        if (Date.now() > transferToken.expiresAt) {
            return next(
                createHttpError(HttpCode.BAD_REQUEST, "Transfer token expired")
            );
        }

        const rawSession = decrypt(
            transferToken.encryptedSession,
            config.getRawConfig().server.secret!
        );

        const isSecure = req.protocol === "https";
        const cookie = serializeSessionCookie(
            rawSession,
            isSecure,
            new Date(session.expiresAt)
        );
        res.appendHeader("Set-Cookie", cookie);

        return response<TransferSessionResponse>(res, {
            data: { valid: true, cookie },
            success: true,
            error: false,
            message: "Session exchanged successfully",
            status: HttpCode.OK
        });
    } catch (e) {
        console.error(e);
        return next(
            createHttpError(
                HttpCode.INTERNAL_SERVER_ERROR,
                "Failed to exchange session"
            )
        );
    }
}
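
Together with getSessionTransferToken above, this endpoint forms a short-lived handshake for moving a session cookie between hosts. A rough sketch of the client side follows; the hostnames and route paths are placeholders, since the router wiring is not shown in this diff:

// Sketch only: hosts and endpoint paths below are placeholders.
async function moveSession(): Promise<void> {
    // 1. Ask the origin host (authenticated; the session cookie is sent)
    //    for a one-time transfer token.
    const res = await fetch(
        "https://origin.example.com/<get-session-transfer-token-route>",
        { credentials: "include" }
    );
    const { data } = await res.json();

    // 2. Within 30 seconds, present the raw token to the target host. It
    //    hashes the token, looks up the stored row, decrypts the session,
    //    and sets the session cookie for its own domain via Set-Cookie.
    await fetch("https://target.example.com/<transfer-session-route>", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        credentials: "include",
        body: JSON.stringify({ token: data.token })
    });
}
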
@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { User } from "@server/db";
import { sendEmailVerificationCode } from "../../auth/sendEmailVerificationCode";
import config from "@server/lib/config";

@@ -3,7 +3,7 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db } from "@server/db";
import { passwordResetTokens, users } from "@server/db";
import { eq } from "drizzle-orm";

@@ -4,7 +4,7 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
import { encodeHex } from "oslo/encoding";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db } from "@server/db";
import { User, users } from "@server/db";
import { eq, and } from "drizzle-orm";
@@ -113,7 +113,7 @@ export async function requestTotpSecret(
    const hex = crypto.getRandomValues(new Uint8Array(20));
    const secret = encodeHex(hex);
    const uri = createTOTPKeyURI(
        "Pangolin",
        config.getRawPrivateConfig().branding?.app_name || "Pangolin",
        user.email!,
        hex
    );

@@ -4,7 +4,7 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db } from "@server/db";
import { passwordResetTokens, users } from "@server/db";
import { eq } from "drizzle-orm";

@@ -6,7 +6,7 @@ import { z } from "zod";
import { db } from "@server/db";
import { User, securityKeys, users, webauthnChallenge } from "@server/db";
import { eq, and, lt } from "drizzle-orm";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import logger from "@server/logger";
import {
    generateRegistrationOptions,

@@ -7,7 +7,7 @@ import { generateId } from "@server/auth/sessions/app";
import logger from "@server/logger";
import { hashPassword } from "@server/auth/password";
import { passwordSchema } from "@server/auth/passwordSchema";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db, users, setupTokens } from "@server/db";
import { eq, and } from "drizzle-orm";
import { UserType } from "@server/types/UserTypes";

@@ -21,7 +21,12 @@ import { hashPassword } from "@server/auth/password";
import { checkValidInvite } from "@server/auth/checkValidInvite";
import { passwordSchema } from "@server/auth/passwordSchema";
import { UserType } from "@server/types/UserTypes";
import { createUserAccountOrg } from "@server/lib/private/createUserAccountOrg";
import { build } from "@server/build";
import resend, {
    AudienceIds,
    moveEmailToAudience
} from "@server/lib/private/resend";

export const signupBodySchema = z.object({
    email: z.string().toLowerCase().email(),
@@ -188,6 +193,26 @@ export async function signup(
    // orgId: null,
    // });

    if (build == "saas") {
        const { success, error, org } = await createUserAccountOrg(
            userId,
            email
        );
        if (!success) {
            if (error) {
                return next(
                    createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)
                );
            }
            return next(
                createHttpError(
                    HttpCode.INTERNAL_SERVER_ERROR,
                    "Failed to create user account and organization"
                )
            );
        }
    }

    const token = generateSessionToken();
    const sess = await createSession(token, userId);
    const isSecure = req.protocol === "https";
@@ -198,6 +223,10 @@ export async function signup(
    );
    res.appendHeader("Set-Cookie", cookie);

    if (build == "saas") {
        moveEmailToAudience(email, AudienceIds.General);
    }

    if (config.getRawConfig().flags?.require_email_verification) {
        sendEmailVerificationCode(email, userId);

@@ -3,13 +3,15 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db, userOrgs } from "@server/db";
import { User, emailVerificationCodes, users } from "@server/db";
import { eq } from "drizzle-orm";
import { isWithinExpirationDate } from "oslo";
import config from "@server/lib/config";
import logger from "@server/logger";
import { freeLimitSet, limitsService } from "@server/lib/private/billing";
import { build } from "@server/build";

export const verifyEmailBody = z
    .object({
@@ -88,6 +90,19 @@ export async function verifyEmail(
        );
    }

    if (build == "saas") {
        const orgs = await db
            .select()
            .from(userOrgs)
            .where(eq(userOrgs.userId, user.userId));
        const orgIds = orgs.map((org) => org.orgId);
        await Promise.all(
            orgIds.map(async (orgId) => {
                await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
            })
        );
    }

    return response<VerifyEmailResponse>(res, {
        success: true,
        error: false,

@@ -3,7 +3,7 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { response } from "@server/lib/response";
import { db } from "@server/db";
import { twoFactorBackupCodes, User, users } from "@server/db";
import { eq, and } from "drizzle-orm";

@@ -15,7 +15,7 @@ import {
import { generateSessionToken, SESSION_COOKIE_EXPIRES } from "@server/auth/sessions/app";
import { SESSION_COOKIE_EXPIRES as RESOURCE_SESSION_COOKIE_EXPIRES } from "@server/auth/sessions/resource";
import config from "@server/lib/config";
import { response } from "@server/lib";
import { response } from "@server/lib/response";

const exchangeSessionBodySchema = z.object({
    requestToken: z.string(),

@@ -11,9 +11,11 @@ import {
    getUserOrgRole,
    getRoleResourceAccess,
    getUserResourceAccess,
    getResourceRules
    getResourceRules,
    getOrgLoginPage
} from "@server/db/queries/verifySessionQueries";
import {
    LoginPage,
    Resource,
    ResourceAccessToken,
    ResourcePassword,
@@ -32,7 +34,9 @@ import createHttpError from "http-errors";
import NodeCache from "node-cache";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { getCountryCodeForIp } from "@server/lib";
import { getCountryCodeForIp, remoteGetCountryCodeForIp } from "@server/lib/geoip";
import { getOrgTierData } from "@server/routers/private/billing";
import { TierId } from "@server/lib/private/billing/tiers";

// We'll see if this speeds anything up
const cache = new NodeCache({
@@ -196,16 +200,7 @@ export async function verifyResourceSession(
        return allowed(res);
    }

    let endpoint: string;
    if (config.isManagedMode()) {
        endpoint =
            config.getRawConfig().managed?.redirect_endpoint ||
            config.getRawConfig().managed?.endpoint ||
            "";
    } else {
        endpoint = config.getRawConfig().app.dashboard_url!;
    }
    const redirectUrl = `${endpoint}/auth/resource/${encodeURIComponent(
    const redirectPath = `/auth/resource/${encodeURIComponent(
        resource.resourceGuid
    )}?redirect=${encodeURIComponent(originalRequestURL)}`;

@@ -408,7 +403,10 @@ export async function verifyResourceSession(
                }. IP: ${clientIp}.`
            );
        }
        return notAllowed(res, redirectUrl);

        logger.debug(`Redirecting to login at ${redirectPath}`);

        return notAllowed(res, redirectPath, resource.orgId);
    } catch (e) {
        console.error(e);
        return next(
@@ -463,7 +461,34 @@ function extractResourceSessionToken(
    return latest.token;
}

function notAllowed(res: Response, redirectUrl?: string) {
async function notAllowed(res: Response, redirectPath?: string, orgId?: string) {
    let loginPage: LoginPage | null = null;
    if (orgId) {
        const { tier } = await getOrgTierData(orgId); // returns null in oss
        if (tier === TierId.STANDARD) {
            loginPage = await getOrgLoginPage(orgId);
        }
    }

    let redirectUrl: string | undefined = undefined;
    if (redirectPath) {
        let endpoint: string;

        if (loginPage && loginPage.domainId && loginPage.fullDomain) {
            const secure = config.getRawConfig().app.dashboard_url?.startsWith("https");
            const method = secure ? "https" : "http";
            endpoint = `${method}://${loginPage.fullDomain}`;
        } else if (config.isManagedMode()) {
            endpoint =
                config.getRawConfig().managed?.redirect_endpoint ||
                config.getRawConfig().managed?.endpoint ||
                "";
        } else {
            endpoint = config.getRawConfig().app.dashboard_url!;
        }
        redirectUrl = `${endpoint}${redirectPath}`;
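        // e.g. an org with a custom login page at "auth.example.org" (a
        // made-up domain) yields
        // "https://auth.example.org/auth/resource/<guid>?redirect=..."
        // instead of the central dashboard URL.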
    }

    const data = {
        data: { valid: false, redirectUrl },
        success: true,

@@ -762,7 +787,11 @@ async function isIpInGeoIP(ip: string, countryCode: string): Promise<boolean> {
    let cachedCountryCode: string | undefined = cache.get(geoIpCacheKey);

    if (!cachedCountryCode) {
        cachedCountryCode = await getCountryCodeForIp(ip);
        if (config.isManagedMode()) {
            cachedCountryCode = await remoteGetCountryCodeForIp(ip);
        } else {
            cachedCountryCode = await getCountryCodeForIp(ip); // do it locally
        }
        // Cache for longer since IP geolocation doesn't change frequently
        cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
    }

@@ -17,7 +17,7 @@ import {
    addPeer as olmAddPeer,
    deletePeer as olmDeletePeer
} from "../olm/peers";
import { sendToExitNode } from "../../lib/exitNodeComms";
import { sendToExitNode } from "@server/lib/exitNodes";

const updateClientParamsSchema = z
    .object({

@@ -9,7 +9,9 @@ import { fromError } from "zod-validation-error";
import { subdomainSchema } from "@server/lib/schemas";
import { generateId } from "@server/auth/sessions/app";
import { eq, and } from "drizzle-orm";
import { isValidDomain } from "@server/lib/validators";
import { usageService } from "@server/lib/private/billing/usageService";
import { FeatureId } from "@server/lib/private/billing";
import { isSecondLevelDomain, isValidDomain } from "@server/lib/validators";
import { build } from "@server/build";
import config from "@server/lib/config";

@@ -98,6 +100,45 @@ export async function createOrgDomain(
        );
    }

    if (isSecondLevelDomain(baseDomain) && type == "cname") {
        // many providers don't allow a CNAME for this, so prevent it for the user for now
        return next(
            createHttpError(
                HttpCode.BAD_REQUEST,
                "You cannot create a CNAME record on a root domain. RFC 1912 § 2.4 prohibits CNAME records at the zone apex. Please use a subdomain."
            )
        );
    }
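    // e.g. baseDomain "example.com" with type "cname" is rejected here,
    // while a subdomain like "app.example.com" is accepted.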

    if (build == "saas") {
        const usage = await usageService.getUsage(orgId, FeatureId.DOMAINS);
        if (!usage) {
            return next(
                createHttpError(
                    HttpCode.NOT_FOUND,
                    "No usage data found for this organization"
                )
            );
        }
        const rejectDomains = await usageService.checkLimitSet(
            orgId,
            false,
            FeatureId.DOMAINS,
            {
                ...usage,
                instantaneousValue: (usage.instantaneousValue || 0) + 1
            } // We need to add one to know if we are violating the limit
        );
        if (rejectDomains) {
            return next(
                createHttpError(
                    HttpCode.FORBIDDEN,
                    "Domain limit exceeded. Please upgrade your plan."
                )
            );
        }
    }

    let numOrgDomains: OrgDomains[] | undefined;
    let aRecords: CreateDomainResponse["aRecords"];
    let cnameRecords: CreateDomainResponse["cnameRecords"];
@@ -260,6 +301,14 @@ export async function createOrgDomain(
            .where(eq(orgDomains.orgId, orgId));
    });

    if (numOrgDomains) {
        await usageService.updateDaily(
            orgId,
            FeatureId.DOMAINS,
            numOrgDomains.length
        );
    }

    if (!returned) {
        return next(
            createHttpError(

@@ -7,6 +7,8 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { and, eq } from "drizzle-orm";
import { usageService } from "@server/lib/private/billing/usageService";
import { FeatureId } from "@server/lib/private/billing";

const paramsSchema = z
    .object({
@@ -88,6 +90,14 @@ export async function deleteAccountDomain(
            .where(eq(orgDomains.orgId, orgId));
    });

    if (numOrgDomains) {
        await usageService.updateDaily(
            orgId,
            FeatureId.DOMAINS,
            numOrgDomains.length
        );
    }

    return response<DeleteAccountDomainResponse>(res, {
        data: { success: true },
        success: true,

@@ -1,4 +1,6 @@
export * from "./listDomains";
export * from "./createOrgDomain";
export * from "./deleteOrgDomain";
export * from "./privateListDomainNamespaces";
export * from "./privateCheckDomainNamespaceAvailability";
export * from "./restartOrgDomain";

Some files were not shown because too many files have changed in this diff.