Merge branch 'dev' into feature/region-rules

This commit is contained in:
Owen
2026-03-30 14:35:42 -07:00
737 changed files with 67335 additions and 28651 deletions

View File

@@ -1,9 +1,10 @@
import { Request } from "express";
import { db } from "@server/db";
import { userActions, roleActions, userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import { userActions, roleActions } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export enum ActionsEnum {
createOrgUser = "createOrgUser",
@@ -19,6 +20,7 @@ export enum ActionsEnum {
getSite = "getSite",
listSites = "listSites",
updateSite = "updateSite",
resetSiteBandwidth = "resetSiteBandwidth",
reGenerateSecret = "reGenerateSecret",
createResource = "createResource",
deleteResource = "deleteResource",
@@ -52,6 +54,8 @@ export enum ActionsEnum {
listRoleResources = "listRoleResources",
// listRoleActions = "listRoleActions",
addUserRole = "addUserRole",
removeUserRole = "removeUserRole",
setUserOrgRoles = "setUserOrgRoles",
// addUserSite = "addUserSite",
// addUserAction = "addUserAction",
// removeUserAction = "removeUserAction",
@@ -78,6 +82,10 @@ export enum ActionsEnum {
updateSiteResource = "updateSiteResource",
createClient = "createClient",
deleteClient = "deleteClient",
archiveClient = "archiveClient",
unarchiveClient = "unarchiveClient",
blockClient = "blockClient",
unblockClient = "unblockClient",
updateClient = "updateClient",
listClients = "listClients",
getClient = "getClient",
@@ -104,6 +112,10 @@ export enum ActionsEnum {
listApiKeyActions = "listApiKeyActions",
listApiKeys = "listApiKeys",
getApiKey = "getApiKey",
createSiteProvisioningKey = "createSiteProvisioningKey",
listSiteProvisioningKeys = "listSiteProvisioningKeys",
updateSiteProvisioningKey = "updateSiteProvisioningKey",
deleteSiteProvisioningKey = "deleteSiteProvisioningKey",
getCertificate = "getCertificate",
restartCertificate = "restartCertificate",
billing = "billing",
@@ -125,7 +137,10 @@ export enum ActionsEnum {
getBlueprint = "getBlueprint",
applyBlueprint = "applyBlueprint",
viewLogs = "viewLogs",
exportLogs = "exportLogs"
exportLogs = "exportLogs",
listApprovals = "listApprovals",
updateApprovals = "updateApprovals",
signSshKey = "signSshKey"
}
export async function checkUserActionPermission(
@@ -146,29 +161,16 @@ export async function checkUserActionPermission(
}
try {
let userOrgRoleId = req.userOrgRoleId;
let userOrgRoleIds = req.userOrgRoleIds;
// If userOrgRoleId is not available on the request, fetch it
if (userOrgRoleId === undefined) {
const userOrgRole = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, req.userOrgId!)
)
)
.limit(1);
if (userOrgRole.length === 0) {
if (userOrgRoleIds === undefined) {
userOrgRoleIds = await getUserOrgRoleIds(userId, req.userOrgId!);
if (userOrgRoleIds.length === 0) {
throw createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
);
}
userOrgRoleId = userOrgRole[0].roleId;
}
// Check if the user has direct permission for the action in the current org
@@ -179,7 +181,7 @@ export async function checkUserActionPermission(
and(
eq(userActions.userId, userId),
eq(userActions.actionId, actionId),
eq(userActions.orgId, req.userOrgId!) // TODO: we cant pass the org id if we are not checking the org
eq(userActions.orgId, req.userOrgId!)
)
)
.limit(1);
@@ -188,14 +190,14 @@ export async function checkUserActionPermission(
return true;
}
// If no direct permission, check role-based permission
// If no direct permission, check role-based permission (any of user's roles)
const roleActionPermission = await db
.select()
.from(roleActions)
.where(
and(
eq(roleActions.actionId, actionId),
eq(roleActions.roleId, userOrgRoleId!),
inArray(roleActions.roleId, userOrgRoleIds),
eq(roleActions.orgId, req.userOrgId!)
)
)

View File

@@ -1,26 +1,29 @@
import { db } from "@server/db";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import { roleResources, userResources } from "@server/db";
export async function canUserAccessResource({
userId,
resourceId,
roleId
roleIds
}: {
userId: string;
resourceId: number;
roleId: number;
roleIds: number[];
}): Promise<boolean> {
const roleResourceAccess = await db
.select()
.from(roleResources)
.where(
and(
eq(roleResources.resourceId, resourceId),
eq(roleResources.roleId, roleId)
)
)
.limit(1);
const roleResourceAccess =
roleIds.length > 0
? await db
.select()
.from(roleResources)
.where(
and(
eq(roleResources.resourceId, resourceId),
inArray(roleResources.roleId, roleIds)
)
)
.limit(1)
: [];
if (roleResourceAccess.length > 0) {
return true;

View File

@@ -0,0 +1,48 @@
import { db } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import { roleSiteResources, userSiteResources } from "@server/db";
export async function canUserAccessSiteResource({
userId,
resourceId,
roleIds
}: {
userId: string;
resourceId: number;
roleIds: number[];
}): Promise<boolean> {
const roleResourceAccess =
roleIds.length > 0
? await db
.select()
.from(roleSiteResources)
.where(
and(
eq(roleSiteResources.siteResourceId, resourceId),
inArray(roleSiteResources.roleId, roleIds)
)
)
.limit(1)
: [];
if (roleResourceAccess.length > 0) {
return true;
}
const userResourceAccess = await db
.select()
.from(userSiteResources)
.where(
and(
eq(userSiteResources.userId, userId),
eq(userSiteResources.siteResourceId, resourceId)
)
)
.limit(1);
if (userResourceAccess.length > 0) {
return true;
}
return false;
}

View File

@@ -3,7 +3,14 @@ import {
encodeHexLowerCase
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { resourceSessions, Session, sessions, User, users } from "@server/db";
import {
resourceSessions,
safeRead,
Session,
sessions,
User,
users
} from "@server/db";
import { db } from "@server/db";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
@@ -54,11 +61,15 @@ export async function validateSessionToken(
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token))
);
const result = await db
.select({ user: users, session: sessions })
.from(sessions)
.innerJoin(users, eq(sessions.userId, users.userId))
.where(eq(sessions.sessionId, sessionId));
const result = await safeRead((db) =>
db
.select({ user: users, session: sessions })
.from(sessions)
.innerJoin(users, eq(sessions.userId, users.userId))
.where(eq(sessions.sessionId, sessionId))
);
if (result.length < 1) {
return { session: null, user: null };
}

View File

@@ -1,7 +1,7 @@
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { resourceSessions, ResourceSession } from "@server/db";
import { db } from "@server/db";
import { db, safeRead } from "@server/db";
import { eq, and } from "drizzle-orm";
import config from "@server/lib/config";
@@ -66,15 +66,17 @@ export async function validateResourceSessionToken(
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token))
);
const result = await db
.select()
.from(resourceSessions)
.where(
and(
eq(resourceSessions.sessionId, sessionId),
eq(resourceSessions.resourceId, resourceId)
const result = await safeRead((db) =>
db
.select()
.from(resourceSessions)
.where(
and(
eq(resourceSessions.sessionId, sessionId),
eq(resourceSessions.resourceId, resourceId)
)
)
);
);
if (result.length < 1) {
return { resourceSession: null };
@@ -85,7 +87,7 @@ export async function validateResourceSessionToken(
if (Date.now() >= resourceSession.expiresAt) {
await db
.delete(resourceSessions)
.where(eq(resourceSessions.sessionId, resourceSessions.sessionId));
.where(eq(resourceSessions.sessionId, sessionId));
return { resourceSession: null };
} else if (
Date.now() >=
@@ -179,7 +181,7 @@ export function serializeResourceSessionCookie(
return `${cookieName}_s.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Secure; Domain=${domain}`;
} else {
if (expiresAt === undefined) {
return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Domain=$domain}`;
return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Domain=${domain}`;
}
return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Domain=${domain}`;
}

View File

@@ -1,6 +1,14 @@
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushConnectionLogToDb } from "#dynamic/routers/newt";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
import { stopPingAccumulator } from "@server/routers/newt/pingAccumulator";
import { cleanup as wsCleanup } from "#dynamic/routers/ws";
async function cleanup() {
await stopPingAccumulator();
await flushBandwidthToDb();
await flushConnectionLogToDb();
await flushSiteBandwidthToDb();
await wsCleanup();
process.exit(0);

View File

@@ -56,15 +56,15 @@ Ensure drizzle-kit is installed.
You must have a connection string in your config file, as shown above.
```bash
npm run db:pg:generate
npm run db:pg:push
npm run db:generate
npm run db:push
```
### SQLite
```bash
npm run db:sqlite:generate
npm run db:sqlite:push
npm run db:generate
npm run db:push
```
## Build Time

View File

@@ -68,7 +68,7 @@ export const MAJOR_ASNS = [
code: "AS36351",
asn: 36351
},
// CDNs
{
name: "Cloudflare",
@@ -90,7 +90,7 @@ export const MAJOR_ASNS = [
code: "AS16625",
asn: 16625
},
// Mobile Carriers - US
{
name: "T-Mobile USA",
@@ -117,7 +117,7 @@ export const MAJOR_ASNS = [
code: "AS6430",
asn: 6430
},
// Mobile Carriers - Europe
{
name: "Vodafone UK",
@@ -144,7 +144,7 @@ export const MAJOR_ASNS = [
code: "AS12430",
asn: 12430
},
// Mobile Carriers - Asia
{
name: "NTT DoCoMo (Japan)",
@@ -176,7 +176,7 @@ export const MAJOR_ASNS = [
code: "AS9808",
asn: 9808
},
// Major US ISPs
{
name: "AT&T Services",
@@ -208,7 +208,7 @@ export const MAJOR_ASNS = [
code: "AS209",
asn: 209
},
// Major European ISPs
{
name: "Deutsche Telekom",
@@ -235,7 +235,7 @@ export const MAJOR_ASNS = [
code: "AS12956",
asn: 12956
},
// Major Asian ISPs
{
name: "China Telecom",
@@ -262,7 +262,7 @@ export const MAJOR_ASNS = [
code: "AS55836",
asn: 55836
},
// VPN/Proxy Providers
{
name: "Private Internet Access",
@@ -279,7 +279,7 @@ export const MAJOR_ASNS = [
code: "AS213281",
asn: 213281
},
// Social Media / Major Tech
{
name: "Facebook/Meta",
@@ -301,7 +301,7 @@ export const MAJOR_ASNS = [
code: "AS2906",
asn: 2906
},
// Academic/Research
{
name: "MIT",

150
server/db/ios_models.json Normal file
View File

@@ -0,0 +1,150 @@
{
"iPad1,1": "iPad",
"iPad2,1": "iPad 2",
"iPad2,2": "iPad 2",
"iPad2,3": "iPad 2",
"iPad2,4": "iPad 2",
"iPad3,1": "iPad 3rd Gen",
"iPad3,3": "iPad 3rd Gen",
"iPad3,2": "iPad 3rd Gen",
"iPad3,4": "iPad 4th Gen",
"iPad3,5": "iPad 4th Gen",
"iPad3,6": "iPad 4th Gen",
"iPad6,11": "iPad 9.7 5th Gen",
"iPad6,12": "iPad 9.7 5th Gen",
"iPad7,5": "iPad 9.7 6th Gen",
"iPad7,6": "iPad 9.7 6th Gen",
"iPad7,11": "iPad 10.2 7th Gen",
"iPad7,12": "iPad 10.2 7th Gen",
"iPad11,6": "iPad 10.2 8th Gen",
"iPad11,7": "iPad 10.2 8th Gen",
"iPad12,1": "iPad 10.2 9th Gen",
"iPad12,2": "iPad 10.2 9th Gen",
"iPad13,18": "iPad 10.9 10th Gen",
"iPad13,19": "iPad 10.9 10th Gen",
"iPad4,1": "iPad Air",
"iPad4,2": "iPad Air",
"iPad4,3": "iPad Air",
"iPad5,3": "iPad Air 2",
"iPad5,4": "iPad Air 2",
"iPad11,3": "iPad Air 3rd Gen",
"iPad11,4": "iPad Air 3rd Gen",
"iPad13,1": "iPad Air 4th Gen",
"iPad13,2": "iPad Air 4th Gen",
"iPad13,16": "iPad Air 5th Gen",
"iPad13,17": "iPad Air 5th Gen",
"iPad14,8": "iPad Air M2 11",
"iPad14,9": "iPad Air M2 11",
"iPad14,10": "iPad Air M2 13",
"iPad14,11": "iPad Air M2 13",
"iPad2,5": "iPad mini",
"iPad2,6": "iPad mini",
"iPad2,7": "iPad mini",
"iPad4,4": "iPad mini 2",
"iPad4,5": "iPad mini 2",
"iPad4,6": "iPad mini 2",
"iPad4,7": "iPad mini 3",
"iPad4,8": "iPad mini 3",
"iPad4,9": "iPad mini 3",
"iPad5,1": "iPad mini 4",
"iPad5,2": "iPad mini 4",
"iPad11,1": "iPad mini 5th Gen",
"iPad11,2": "iPad mini 5th Gen",
"iPad14,1": "iPad mini 6th Gen",
"iPad14,2": "iPad mini 6th Gen",
"iPad6,7": "iPad Pro 12.9",
"iPad6,8": "iPad Pro 12.9",
"iPad6,3": "iPad Pro 9.7",
"iPad6,4": "iPad Pro 9.7",
"iPad7,3": "iPad Pro 10.5",
"iPad7,4": "iPad Pro 10.5",
"iPad7,1": "iPad Pro 12.9",
"iPad7,2": "iPad Pro 12.9",
"iPad8,1": "iPad Pro 11",
"iPad8,2": "iPad Pro 11",
"iPad8,3": "iPad Pro 11",
"iPad8,4": "iPad Pro 11",
"iPad8,5": "iPad Pro 12.9",
"iPad8,6": "iPad Pro 12.9",
"iPad8,7": "iPad Pro 12.9",
"iPad8,8": "iPad Pro 12.9",
"iPad8,9": "iPad Pro 11",
"iPad8,10": "iPad Pro 11",
"iPad8,11": "iPad Pro 12.9",
"iPad8,12": "iPad Pro 12.9",
"iPad13,4": "iPad Pro 11",
"iPad13,5": "iPad Pro 11",
"iPad13,6": "iPad Pro 11",
"iPad13,7": "iPad Pro 11",
"iPad13,8": "iPad Pro 12.9",
"iPad13,9": "iPad Pro 12.9",
"iPad13,10": "iPad Pro 12.9",
"iPad13,11": "iPad Pro 12.9",
"iPad14,3": "iPad Pro 11",
"iPad14,4": "iPad Pro 11",
"iPad14,5": "iPad Pro 12.9",
"iPad14,6": "iPad Pro 12.9",
"iPad16,3": "iPad Pro M4 11",
"iPad16,4": "iPad Pro M4 11",
"iPad16,5": "iPad Pro M4 13",
"iPad16,6": "iPad Pro M4 13",
"iPhone1,1": "iPhone",
"iPhone1,2": "iPhone 3G",
"iPhone2,1": "iPhone 3GS",
"iPhone3,1": "iPhone 4",
"iPhone3,2": "iPhone 4",
"iPhone3,3": "iPhone 4",
"iPhone4,1": "iPhone 4S",
"iPhone5,1": "iPhone 5",
"iPhone5,2": "iPhone 5",
"iPhone5,3": "iPhone 5c",
"iPhone5,4": "iPhone 5c",
"iPhone6,1": "iPhone 5s",
"iPhone6,2": "iPhone 5s",
"iPhone7,2": "iPhone 6",
"iPhone7,1": "iPhone 6 Plus",
"iPhone8,1": "iPhone 6s",
"iPhone8,2": "iPhone 6s Plus",
"iPhone8,4": "iPhone SE",
"iPhone9,1": "iPhone 7",
"iPhone9,3": "iPhone 7",
"iPhone9,2": "iPhone 7 Plus",
"iPhone9,4": "iPhone 7 Plus",
"iPhone10,1": "iPhone 8",
"iPhone10,4": "iPhone 8",
"iPhone10,2": "iPhone 8 Plus",
"iPhone10,5": "iPhone 8 Plus",
"iPhone10,3": "iPhone X",
"iPhone10,6": "iPhone X",
"iPhone11,2": "iPhone Xs",
"iPhone11,6": "iPhone Xs Max",
"iPhone11,8": "iPhone XR",
"iPhone12,1": "iPhone 11",
"iPhone12,3": "iPhone 11 Pro",
"iPhone12,5": "iPhone 11 Pro Max",
"iPhone12,8": "iPhone SE",
"iPhone13,1": "iPhone 12 mini",
"iPhone13,2": "iPhone 12",
"iPhone13,3": "iPhone 12 Pro",
"iPhone13,4": "iPhone 12 Pro Max",
"iPhone14,4": "iPhone 13 mini",
"iPhone14,5": "iPhone 13",
"iPhone14,2": "iPhone 13 Pro",
"iPhone14,3": "iPhone 13 Pro Max",
"iPhone14,6": "iPhone SE",
"iPhone14,7": "iPhone 14",
"iPhone14,8": "iPhone 14 Plus",
"iPhone15,2": "iPhone 14 Pro",
"iPhone15,3": "iPhone 14 Pro Max",
"iPhone15,4": "iPhone 15",
"iPhone15,5": "iPhone 15 Plus",
"iPhone16,1": "iPhone 15 Pro",
"iPhone16,2": "iPhone 15 Pro Max",
"iPod1,1": "iPod touch Original",
"iPod2,1": "iPod touch 2nd",
"iPod3,1": "iPod touch 3rd Gen",
"iPod4,1": "iPod touch 4th",
"iPod5,1": "iPod touch 5th",
"iPod7,1": "iPod touch 6th Gen",
"iPod9,1": "iPod touch 7th Gen"
}

201
server/db/mac_models.json Normal file
View File

@@ -0,0 +1,201 @@
{
"PowerMac4,4": "eMac",
"PowerMac6,4": "eMac",
"PowerBook2,1": "iBook",
"PowerBook2,2": "iBook",
"PowerBook4,1": "iBook",
"PowerBook4,2": "iBook",
"PowerBook4,3": "iBook",
"PowerBook6,3": "iBook",
"PowerBook6,5": "iBook",
"PowerBook6,7": "iBook",
"iMac,1": "iMac",
"PowerMac2,1": "iMac",
"PowerMac2,2": "iMac",
"PowerMac4,1": "iMac",
"PowerMac4,2": "iMac",
"PowerMac4,5": "iMac",
"PowerMac6,1": "iMac",
"PowerMac6,3*": "iMac",
"PowerMac6,3": "iMac",
"PowerMac8,1": "iMac",
"PowerMac8,2": "iMac",
"PowerMac12,1": "iMac",
"iMac4,1": "iMac",
"iMac4,2": "iMac",
"iMac5,2": "iMac",
"iMac5,1": "iMac",
"iMac6,1": "iMac",
"iMac7,1": "iMac",
"iMac8,1": "iMac",
"iMac9,1": "iMac",
"iMac10,1": "iMac",
"iMac11,1": "iMac",
"iMac11,2": "iMac",
"iMac11,3": "iMac",
"iMac12,1": "iMac",
"iMac12,2": "iMac",
"iMac13,1": "iMac",
"iMac13,2": "iMac",
"iMac14,1": "iMac",
"iMac14,3": "iMac",
"iMac14,2": "iMac",
"iMac14,4": "iMac",
"iMac15,1": "iMac",
"iMac16,1": "iMac",
"iMac16,2": "iMac",
"iMac17,1": "iMac",
"iMac18,1": "iMac",
"iMac18,2": "iMac",
"iMac18,3": "iMac",
"iMac19,2": "iMac",
"iMac19,1": "iMac",
"iMac20,1": "iMac",
"iMac20,2": "iMac",
"iMac21,2": "iMac",
"iMac21,1": "iMac",
"iMacPro1,1": "iMac Pro",
"PowerMac10,1": "Mac mini",
"PowerMac10,2": "Mac mini",
"Macmini1,1": "Mac mini",
"Macmini2,1": "Mac mini",
"Macmini3,1": "Mac mini",
"Macmini4,1": "Mac mini",
"Macmini5,1": "Mac mini",
"Macmini5,2": "Mac mini",
"Macmini5,3": "Mac mini",
"Macmini6,1": "Mac mini",
"Macmini6,2": "Mac mini",
"Macmini7,1": "Mac mini",
"Macmini8,1": "Mac mini",
"ADP3,2": "Mac mini",
"Macmini9,1": "Mac mini",
"Mac14,3": "Mac mini",
"Mac14,12": "Mac mini",
"MacPro1,1*": "Mac Pro",
"MacPro2,1": "Mac Pro",
"MacPro3,1": "Mac Pro",
"MacPro4,1": "Mac Pro",
"MacPro5,1": "Mac Pro",
"MacPro6,1": "Mac Pro",
"MacPro7,1": "Mac Pro",
"N/A*": "Power Macintosh",
"PowerMac1,1": "Power Macintosh",
"PowerMac3,1": "Power Macintosh",
"PowerMac3,3": "Power Macintosh",
"PowerMac3,4": "Power Macintosh",
"PowerMac3,5": "Power Macintosh",
"PowerMac3,6": "Power Macintosh",
"Mac13,1": "Mac Studio",
"Mac13,2": "Mac Studio",
"MacBook1,1": "MacBook",
"MacBook2,1": "MacBook",
"MacBook3,1": "MacBook",
"MacBook4,1": "MacBook",
"MacBook5,1": "MacBook",
"MacBook5,2": "MacBook",
"MacBook6,1": "MacBook",
"MacBook7,1": "MacBook",
"MacBook8,1": "MacBook",
"MacBook9,1": "MacBook",
"MacBook10,1": "MacBook",
"MacBookAir1,1": "MacBook Air",
"MacBookAir2,1": "MacBook Air",
"MacBookAir3,1": "MacBook Air",
"MacBookAir3,2": "MacBook Air",
"MacBookAir4,1": "MacBook Air",
"MacBookAir4,2": "MacBook Air",
"MacBookAir5,1": "MacBook Air",
"MacBookAir5,2": "MacBook Air",
"MacBookAir6,1": "MacBook Air",
"MacBookAir6,2": "MacBook Air",
"MacBookAir7,1": "MacBook Air",
"MacBookAir7,2": "MacBook Air",
"MacBookAir8,1": "MacBook Air",
"MacBookAir8,2": "MacBook Air",
"MacBookAir9,1": "MacBook Air",
"MacBookAir10,1": "MacBook Air",
"Mac14,2": "MacBook Air",
"MacBookPro1,1": "MacBook Pro",
"MacBookPro1,2": "MacBook Pro",
"MacBookPro2,2": "MacBook Pro",
"MacBookPro2,1": "MacBook Pro",
"MacBookPro3,1": "MacBook Pro",
"MacBookPro4,1": "MacBook Pro",
"MacBookPro5,1": "MacBook Pro",
"MacBookPro5,2": "MacBook Pro",
"MacBookPro5,5": "MacBook Pro",
"MacBookPro5,4": "MacBook Pro",
"MacBookPro5,3": "MacBook Pro",
"MacBookPro7,1": "MacBook Pro",
"MacBookPro6,2": "MacBook Pro",
"MacBookPro6,1": "MacBook Pro",
"MacBookPro8,1": "MacBook Pro",
"MacBookPro8,2": "MacBook Pro",
"MacBookPro8,3": "MacBook Pro",
"MacBookPro9,2": "MacBook Pro",
"MacBookPro9,1": "MacBook Pro",
"MacBookPro10,1": "MacBook Pro",
"MacBookPro10,2": "MacBook Pro",
"MacBookPro11,1": "MacBook Pro",
"MacBookPro11,2": "MacBook Pro",
"MacBookPro11,3": "MacBook Pro",
"MacBookPro12,1": "MacBook Pro",
"MacBookPro11,4": "MacBook Pro",
"MacBookPro11,5": "MacBook Pro",
"MacBookPro13,1": "MacBook Pro",
"MacBookPro13,2": "MacBook Pro",
"MacBookPro13,3": "MacBook Pro",
"MacBookPro14,1": "MacBook Pro",
"MacBookPro14,2": "MacBook Pro",
"MacBookPro14,3": "MacBook Pro",
"MacBookPro15,2": "MacBook Pro",
"MacBookPro15,1": "MacBook Pro",
"MacBookPro15,3": "MacBook Pro",
"MacBookPro15,4": "MacBook Pro",
"MacBookPro16,1": "MacBook Pro",
"MacBookPro16,3": "MacBook Pro",
"MacBookPro16,2": "MacBook Pro",
"MacBookPro16,4": "MacBook Pro",
"MacBookPro17,1": "MacBook Pro",
"MacBookPro18,3": "MacBook Pro",
"MacBookPro18,4": "MacBook Pro",
"MacBookPro18,1": "MacBook Pro",
"MacBookPro18,2": "MacBook Pro",
"Mac14,7": "MacBook Pro",
"Mac14,9": "MacBook Pro",
"Mac14,5": "MacBook Pro",
"Mac14,10": "MacBook Pro",
"Mac14,6": "MacBook Pro",
"PowerMac1,2": "Power Macintosh",
"PowerMac5,1": "Power Macintosh",
"PowerMac7,2": "Power Macintosh",
"PowerMac7,3": "Power Macintosh",
"PowerMac9,1": "Power Macintosh",
"PowerMac11,2": "Power Macintosh",
"PowerBook1,1": "PowerBook",
"PowerBook3,1": "PowerBook",
"PowerBook3,2": "PowerBook",
"PowerBook3,3": "PowerBook",
"PowerBook3,4": "PowerBook",
"PowerBook3,5": "PowerBook",
"PowerBook6,1": "PowerBook",
"PowerBook5,1": "PowerBook",
"PowerBook6,2": "PowerBook",
"PowerBook5,2": "PowerBook",
"PowerBook5,3": "PowerBook",
"PowerBook6,4": "PowerBook",
"PowerBook5,4": "PowerBook",
"PowerBook5,5": "PowerBook",
"PowerBook6,8": "PowerBook",
"PowerBook5,6": "PowerBook",
"PowerBook5,7": "PowerBook",
"PowerBook5,8": "PowerBook",
"PowerBook5,9": "PowerBook",
"RackMac1,1": "Xserve",
"RackMac1,2": "Xserve",
"RackMac3,1": "Xserve",
"Xserve1,1": "Xserve",
"Xserve2,1": "Xserve",
"Xserve3,1": "Xserve"
}

3
server/db/migrate.ts Normal file
View File

@@ -0,0 +1,3 @@
import { runMigrations } from "./";
await runMigrations();

View File

@@ -16,6 +16,24 @@ if (!dev) {
}
export const names = JSON.parse(readFileSync(file, "utf-8"));
// Load iOS and Mac model mappings
let iosModelsFile: string;
let macModelsFile: string;
if (!dev) {
iosModelsFile = join(__DIRNAME, "ios_models.json");
macModelsFile = join(__DIRNAME, "mac_models.json");
} else {
iosModelsFile = join("server/db/ios_models.json");
macModelsFile = join("server/db/mac_models.json");
}
const iosModels: Record<string, string> = JSON.parse(
readFileSync(iosModelsFile, "utf-8")
);
const macModels: Record<string, string> = JSON.parse(
readFileSync(macModelsFile, "utf-8")
);
export async function getUniqueClientName(orgId: string): Promise<string> {
let loops = 0;
while (true) {
@@ -159,3 +177,29 @@ export function generateName(): string {
// clean out any non-alphanumeric characters except for dashes
return name.replace(/[^a-z0-9-]/g, "");
}
export function getMacDeviceName(macIdentifier?: string | null): string | null {
if (macIdentifier && macModels[macIdentifier]) {
return macModels[macIdentifier];
}
return null;
}
export function getIosDeviceName(iosIdentifier?: string | null): string | null {
if (iosIdentifier && iosModels[iosIdentifier]) {
return iosModels[iosIdentifier];
}
return null;
}
export function getUserDeviceName(
model: string | null,
fallBack: string | null
): string {
return (
getMacDeviceName(model) ||
getIosDeviceName(model) ||
fallBack ||
"Unknown Device"
);
}

View File

@@ -1,33 +1,33 @@
import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres";
import { Pool } from "pg";
import { readConfigFile } from "@server/lib/readConfigFile";
import { withReplicas } from "drizzle-orm/pg-core";
import { createPool } from "./poolConfig";
function createDb() {
const config = readConfigFile();
if (!config.postgres) {
// check the environment variables for postgres config
if (process.env.POSTGRES_CONNECTION_STRING) {
config.postgres = {
connection_string: process.env.POSTGRES_CONNECTION_STRING
};
if (process.env.POSTGRES_REPLICA_CONNECTION_STRINGS) {
const replicas =
process.env.POSTGRES_REPLICA_CONNECTION_STRINGS.split(
","
).map((conn) => ({
// check the environment variables for postgres config first before the config file
if (process.env.POSTGRES_CONNECTION_STRING) {
config.postgres = {
connection_string: process.env.POSTGRES_CONNECTION_STRING
};
if (process.env.POSTGRES_REPLICA_CONNECTION_STRINGS) {
const replicas =
process.env.POSTGRES_REPLICA_CONNECTION_STRINGS.split(",").map(
(conn) => ({
connection_string: conn.trim()
}));
config.postgres.replicas = replicas;
}
} else {
throw new Error(
"Postgres configuration is missing in the configuration file."
);
})
);
config.postgres.replicas = replicas;
}
}
if (!config.postgres) {
throw new Error(
"Postgres configuration is missing in the configuration file."
);
}
const connectionString = config.postgres?.connection_string;
const replicaConnections = config.postgres?.replicas || [];
@@ -39,12 +39,17 @@ function createDb() {
// Create connection pools instead of individual connections
const poolConfig = config.postgres.pool;
const primaryPool = new Pool({
const maxConnections = poolConfig?.max_connections || 20;
const idleTimeoutMs = poolConfig?.idle_timeout_ms || 30000;
const connectionTimeoutMs = poolConfig?.connection_timeout_ms || 5000;
const primaryPool = createPool(
connectionString,
max: poolConfig?.max_connections || 20,
idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000,
connectionTimeoutMillis: poolConfig?.connection_timeout_ms || 5000
});
maxConnections,
idleTimeoutMs,
connectionTimeoutMs,
"primary"
);
const replicas = [];
@@ -55,14 +60,16 @@ function createDb() {
})
);
} else {
const maxReplicaConnections =
poolConfig?.max_replica_connections || 20;
for (const conn of replicaConnections) {
const replicaPool = new Pool({
connectionString: conn.connection_string,
max: poolConfig?.max_replica_connections || 20,
idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000,
connectionTimeoutMillis:
poolConfig?.connection_timeout_ms || 5000
});
const replicaPool = createPool(
conn.connection_string,
maxReplicaConnections,
idleTimeoutMs,
connectionTimeoutMs,
"replica"
);
replicas.push(
DrizzlePostgres(replicaPool, {
logger: process.env.QUERY_LOGGING == "true"
@@ -81,6 +88,7 @@ function createDb() {
export const db = createDb();
export default db;
export const primaryDb = db.$primary;
export type Transaction = Parameters<
Parameters<(typeof db)["transaction"]>[0]
>[0];
>[0];

View File

@@ -1,3 +1,6 @@
export * from "./driver";
export * from "./logsDriver";
export * from "./safeRead";
export * from "./schema/schema";
export * from "./schema/privateSchema";
export * from "./migrate";

View File

@@ -0,0 +1,94 @@
import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres";
import { readConfigFile } from "@server/lib/readConfigFile";
import { withReplicas } from "drizzle-orm/pg-core";
import { build } from "@server/build";
import { db as mainDb, primaryDb as mainPrimaryDb } from "./driver";
import { createPool } from "./poolConfig";
function createLogsDb() {
// Only use separate logs database in SaaS builds
if (build !== "saas") {
return mainDb;
}
const config = readConfigFile();
// Merge configs, prioritizing private config
const logsConfig = config.postgres_logs;
// Check environment variable first
let connectionString = process.env.POSTGRES_LOGS_CONNECTION_STRING;
let replicaConnections: Array<{ connection_string: string }> = [];
if (!connectionString && logsConfig) {
connectionString = logsConfig.connection_string;
replicaConnections = logsConfig.replicas || [];
}
// If POSTGRES_LOGS_REPLICA_CONNECTION_STRINGS is set, use it
if (process.env.POSTGRES_LOGS_REPLICA_CONNECTION_STRINGS) {
replicaConnections =
process.env.POSTGRES_LOGS_REPLICA_CONNECTION_STRINGS.split(",").map(
(conn) => ({
connection_string: conn.trim()
})
);
}
// If no logs database is configured, fall back to main database
if (!connectionString) {
return mainDb;
}
// Create separate connection pool for logs database
const poolConfig = logsConfig?.pool || config.postgres?.pool;
const maxConnections = poolConfig?.max_connections || 20;
const idleTimeoutMs = poolConfig?.idle_timeout_ms || 30000;
const connectionTimeoutMs = poolConfig?.connection_timeout_ms || 5000;
const primaryPool = createPool(
connectionString,
maxConnections,
idleTimeoutMs,
connectionTimeoutMs,
"logs-primary"
);
const replicas = [];
if (!replicaConnections.length) {
replicas.push(
DrizzlePostgres(primaryPool, {
logger: process.env.QUERY_LOGGING == "true"
})
);
} else {
const maxReplicaConnections =
poolConfig?.max_replica_connections || 20;
for (const conn of replicaConnections) {
const replicaPool = createPool(
conn.connection_string,
maxReplicaConnections,
idleTimeoutMs,
connectionTimeoutMs,
"logs-replica"
);
replicas.push(
DrizzlePostgres(replicaPool, {
logger: process.env.QUERY_LOGGING == "true"
})
);
}
}
return withReplicas(
DrizzlePostgres(primaryPool, {
logger: process.env.QUERY_LOGGING == "true"
}),
replicas as any
);
}
export const logsDb = createLogsDb();
export default logsDb;
export const primaryLogsDb = logsDb.$primary;

View File

@@ -4,18 +4,16 @@ import path from "path";
const migrationsFolder = path.join("server/migrations");
const runMigrations = async () => {
export const runMigrations = async () => {
console.log("Running migrations...");
try {
await migrate(db as any, {
migrationsFolder: migrationsFolder
});
console.log("Migrations completed successfully.");
console.log("Migrations completed successfully.");
process.exit(0);
} catch (error) {
console.error("Error running migrations:", error);
process.exit(1);
}
};
runMigrations();

View File

@@ -0,0 +1,63 @@
import { Pool, PoolConfig } from "pg";
import logger from "@server/logger";
export function createPoolConfig(
connectionString: string,
maxConnections: number,
idleTimeoutMs: number,
connectionTimeoutMs: number
): PoolConfig {
return {
connectionString,
max: maxConnections,
idleTimeoutMillis: idleTimeoutMs,
connectionTimeoutMillis: connectionTimeoutMs,
// TCP keepalive to prevent silent connection drops by NAT gateways,
// load balancers, and other intermediate network devices (e.g. AWS
// NAT Gateway drops idle TCP connections after ~350s)
keepAlive: true,
keepAliveInitialDelayMillis: 10000, // send first keepalive after 10s of idle
// Allow connections to be released and recreated more aggressively
// to avoid stale connections building up
allowExitOnIdle: false
};
}
export function attachPoolErrorHandlers(pool: Pool, label: string): void {
pool.on("error", (err) => {
// This catches errors on idle clients in the pool. Without this
// handler an unexpected disconnect would crash the process.
logger.error(
`Unexpected error on idle ${label} database client: ${err.message}`
);
});
pool.on("connect", (client) => {
// Set a statement timeout on every new connection so a single slow
// query can't block the pool forever
client.query("SET statement_timeout = '30s'").catch((err: Error) => {
logger.warn(
`Failed to set statement_timeout on ${label} client: ${err.message}`
);
});
});
}
export function createPool(
connectionString: string,
maxConnections: number,
idleTimeoutMs: number,
connectionTimeoutMs: number,
label: string
): Pool {
const pool = new Pool(
createPoolConfig(
connectionString,
maxConnections,
idleTimeoutMs,
connectionTimeoutMs
)
);
attachPoolErrorHandlers(pool, label);
return pool;
}

24
server/db/pg/safeRead.ts Normal file
View File

@@ -0,0 +1,24 @@
import { db, primaryDb } from "./driver";
/**
* Runs a read query with replica fallback for Postgres.
* Executes the query against the replica first (when replicas exist).
* If the query throws or returns no data (null, undefined, or empty array),
* runs the same query against the primary.
*/
export async function safeRead<T>(
query: (d: typeof db | typeof primaryDb) => Promise<T>
): Promise<T> {
try {
const result = await query(db);
if (result === undefined || result === null) {
return query(primaryDb);
}
if (Array.isArray(result) && result.length === 0) {
return query(primaryDb);
}
return result;
} catch {
return query(primaryDb);
}
}

View File

@@ -7,10 +7,21 @@ import {
bigint,
real,
text,
index
index,
primaryKey
} from "drizzle-orm/pg-core";
import { InferSelectModel } from "drizzle-orm";
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";
import {
domains,
orgs,
targets,
users,
exitNodes,
sessions,
clients,
siteResources,
sites
} from "./schema";
export const certificates = pgTable("certificates", {
certId: serial("certId").primaryKey(),
@@ -74,11 +85,16 @@ export const subscriptions = pgTable("subscriptions", {
canceledAt: bigint("canceledAt", { mode: "number" }),
createdAt: bigint("createdAt", { mode: "number" }).notNull(),
updatedAt: bigint("updatedAt", { mode: "number" }),
billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" })
version: integer("version"),
billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" }),
type: varchar("type", { length: 50 }) // tier1, tier2, tier3, or license
});
export const subscriptionItems = pgTable("subscriptionItems", {
subscriptionItemId: serial("subscriptionItemId").primaryKey(),
stripeSubscriptionItemId: varchar("stripeSubscriptionItemId", {
length: 255
}),
subscriptionId: varchar("subscriptionId", { length: 255 })
.notNull()
.references(() => subscriptions.subscriptionId, {
@@ -86,6 +102,7 @@ export const subscriptionItems = pgTable("subscriptionItems", {
}),
planId: varchar("planId", { length: 255 }).notNull(),
priceId: varchar("priceId", { length: 255 }),
featureId: varchar("featureId", { length: 255 }),
meterId: varchar("meterId", { length: 255 }),
unitAmount: real("unitAmount"),
tiers: text("tiers"),
@@ -128,6 +145,7 @@ export const limits = pgTable("limits", {
})
.notNull(),
value: real("value"),
override: boolean("override").default(false),
description: text("description")
});
@@ -204,6 +222,29 @@ export const loginPageOrg = pgTable("loginPageOrg", {
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const loginPageBranding = pgTable("loginPageBranding", {
loginPageBrandingId: serial("loginPageBrandingId").primaryKey(),
logoUrl: text("logoUrl"),
logoWidth: integer("logoWidth").notNull(),
logoHeight: integer("logoHeight").notNull(),
primaryColor: text("primaryColor"),
resourceTitle: text("resourceTitle").notNull(),
resourceSubtitle: text("resourceSubtitle"),
orgTitle: text("orgTitle"),
orgSubtitle: text("orgSubtitle")
});
export const loginPageBrandingOrg = pgTable("loginPageBrandingOrg", {
loginPageBrandingId: integer("loginPageBrandingId")
.notNull()
.references(() => loginPageBranding.loginPageBrandingId, {
onDelete: "cascade"
}),
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const sessionTransferToken = pgTable("sessionTransferToken", {
token: varchar("token").primaryKey(),
sessionId: varchar("sessionId")
@@ -250,6 +291,7 @@ export const accessAuditLog = pgTable(
actor: varchar("actor", { length: 255 }),
actorId: varchar("actorId", { length: 255 }),
resourceId: integer("resourceId"),
siteResourceId: integer("siteResourceId"),
ip: varchar("ip", { length: 45 }),
type: varchar("type", { length: 100 }).notNull(),
action: boolean("action").notNull(),
@@ -266,6 +308,115 @@ export const accessAuditLog = pgTable(
]
);
export const connectionAuditLog = pgTable(
"connectionAuditLog",
{
id: serial("id").primaryKey(),
sessionId: text("sessionId").notNull(),
siteResourceId: integer("siteResourceId").references(
() => siteResources.siteResourceId,
{ onDelete: "cascade" }
),
orgId: text("orgId").references(() => orgs.orgId, {
onDelete: "cascade"
}),
siteId: integer("siteId").references(() => sites.siteId, {
onDelete: "cascade"
}),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}),
userId: text("userId").references(() => users.userId, {
onDelete: "cascade"
}),
sourceAddr: text("sourceAddr").notNull(),
destAddr: text("destAddr").notNull(),
protocol: text("protocol").notNull(),
startedAt: integer("startedAt").notNull(),
endedAt: integer("endedAt"),
bytesTx: integer("bytesTx"),
bytesRx: integer("bytesRx")
},
(table) => [
index("idx_accessAuditLog_startedAt").on(table.startedAt),
index("idx_accessAuditLog_org_startedAt").on(
table.orgId,
table.startedAt
),
index("idx_accessAuditLog_siteResourceId").on(table.siteResourceId)
]
);
export const approvals = pgTable("approvals", {
approvalId: serial("approvalId").primaryKey(),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}), // clients reference user devices (in this case)
userId: varchar("userId")
.references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
.notNull(),
decision: varchar("decision")
.$type<"approved" | "denied" | "pending">()
.default("pending")
.notNull(),
type: varchar("type")
.$type<"user_device" /*| 'proxy' // for later */>()
.notNull()
});
export const bannedEmails = pgTable("bannedEmails", {
email: varchar("email", { length: 255 }).primaryKey()
});
export const bannedIps = pgTable("bannedIps", {
ip: varchar("ip", { length: 255 }).primaryKey()
});
export const siteProvisioningKeys = pgTable("siteProvisioningKeys", {
siteProvisioningKeyId: varchar("siteProvisioningKeyId", {
length: 255
}).primaryKey(),
name: varchar("name", { length: 255 }).notNull(),
siteProvisioningKeyHash: text("siteProvisioningKeyHash").notNull(),
lastChars: varchar("lastChars", { length: 4 }).notNull(),
createdAt: varchar("dateCreated", { length: 255 }).notNull(),
lastUsed: varchar("lastUsed", { length: 255 }),
maxBatchSize: integer("maxBatchSize"), // null = no limit
numUsed: integer("numUsed").notNull().default(0),
validUntil: varchar("validUntil", { length: 255 })
});
export const siteProvisioningKeyOrg = pgTable(
"siteProvisioningKeyOrg",
{
siteProvisioningKeyId: varchar("siteProvisioningKeyId", {
length: 255
})
.notNull()
.references(() => siteProvisioningKeys.siteProvisioningKeyId, {
onDelete: "cascade"
}),
orgId: varchar("orgId", { length: 255 })
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
},
(table) => [
primaryKey({
columns: [table.siteProvisioningKeyId, table.orgId]
})
]
);
export type Approval = InferSelectModel<typeof approvals>;
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
export type Certificate = InferSelectModel<typeof certificates>;
@@ -283,5 +434,7 @@ export type RemoteExitNodeSession = InferSelectModel<
>;
export type ExitNodeOrg = InferSelectModel<typeof exitNodeOrgs>;
export type LoginPage = InferSelectModel<typeof loginPage>;
export type LoginPageBranding = InferSelectModel<typeof loginPageBranding>;
export type ActionAuditLog = InferSelectModel<typeof actionAuditLog>;
export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>;
export type ConnectionAuditLog = InferSelectModel<typeof connectionAuditLog>;

View File

@@ -1,17 +1,18 @@
import {
pgTable,
serial,
varchar,
boolean,
integer,
bigint,
real,
text,
index
} from "drizzle-orm/pg-core";
import { InferSelectModel } from "drizzle-orm";
import { randomUUID } from "crypto";
import { alias } from "yargs";
import { InferSelectModel } from "drizzle-orm";
import {
bigint,
boolean,
index,
integer,
pgTable,
primaryKey,
real,
serial,
text,
unique,
varchar
} from "drizzle-orm/pg-core";
export const domains = pgTable("domains", {
domainId: varchar("domainId").primaryKey(),
@@ -23,7 +24,8 @@ export const domains = pgTable("domains", {
tries: integer("tries").notNull().default(0),
certResolver: varchar("certResolver"),
customCertResolver: varchar("customCertResolver"),
preferWildcardCert: boolean("preferWildcardCert")
preferWildcardCert: boolean("preferWildcardCert"),
errorMessage: text("errorMessage")
});
export const dnsRecords = pgTable("dnsRecords", {
@@ -54,7 +56,14 @@ export const orgs = pgTable("orgs", {
.default(0),
settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year
.notNull()
.default(0)
.default(0),
settingsLogRetentionDaysConnection: integer("settingsLogRetentionDaysConnection") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year
.notNull()
.default(0),
sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format)
sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format)
isBillingOrg: boolean("isBillingOrg"),
billingOrgId: varchar("billingOrgId")
});
export const orgDomains = pgTable("orgDomains", {
@@ -85,6 +94,7 @@ export const sites = pgTable("sites", {
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
type: varchar("type").notNull(), // "newt" or "wireguard"
online: boolean("online").notNull().default(false),
lastPing: integer("lastPing"),
address: varchar("address"),
endpoint: varchar("endpoint"),
publicKey: varchar("publicKey"),
@@ -131,7 +141,18 @@ export const resources = pgTable("resources", {
}),
headers: text("headers"), // comma-separated list of headers to add to the request
proxyProtocol: boolean("proxyProtocol").notNull().default(false),
proxyProtocolVersion: integer("proxyProtocolVersion").default(1)
proxyProtocolVersion: integer("proxyProtocolVersion").default(1),
maintenanceModeEnabled: boolean("maintenanceModeEnabled")
.notNull()
.default(false),
maintenanceModeType: text("maintenanceModeType", {
enum: ["forced", "automatic"]
}).default("forced"), // "forced" = always show, "automatic" = only when down
maintenanceTitle: text("maintenanceTitle"),
maintenanceMessage: text("maintenanceMessage"),
maintenanceEstimatedTime: text("maintenanceEstimatedTime"),
postAuthPath: text("postAuthPath")
});
export const targets = pgTable("targets", {
@@ -176,7 +197,9 @@ export const targetHealthCheck = pgTable("targetHealthCheck", {
hcFollowRedirects: boolean("hcFollowRedirects").default(true),
hcMethod: varchar("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy"
hcHealth: text("hcHealth")
.$type<"unknown" | "healthy" | "unhealthy">()
.default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName")
});
@@ -206,14 +229,21 @@ export const siteResources = pgTable("siteResources", {
.references(() => orgs.orgId, { onDelete: "cascade" }),
niceId: varchar("niceId").notNull(),
name: varchar("name").notNull(),
mode: varchar("mode").notNull(), // "host" | "cidr" | "port"
mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
protocol: varchar("protocol"), // only for port mode
proxyPort: integer("proxyPort"), // only for port mode
destinationPort: integer("destinationPort"), // only for port mode
destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode
enabled: boolean("enabled").notNull().default(true),
alias: varchar("alias"),
aliasAddress: varchar("aliasAddress")
aliasAddress: varchar("aliasAddress"),
tcpPortRangeString: varchar("tcpPortRangeString").notNull().default("*"),
udpPortRangeString: varchar("udpPortRangeString").notNull().default("*"),
disableIcmp: boolean("disableIcmp").notNull().default(false),
authDaemonPort: integer("authDaemonPort").default(22123),
authDaemonMode: varchar("authDaemonMode", { length: 32 })
.$type<"site" | "remote">()
.default("site")
});
export const clientSiteResources = pgTable("clientSiteResources", {
@@ -260,6 +290,7 @@ export const users = pgTable("user", {
dateCreated: varchar("dateCreated").notNull(),
termsAcceptedTimestamp: varchar("termsAcceptedTimestamp"),
termsVersion: varchar("termsVersion"),
marketingEmailConsent: boolean("marketingEmailConsent").default(false),
serverAdmin: boolean("serverAdmin").notNull().default(false),
lastPasswordChange: bigint("lastPasswordChange", { mode: "number" })
});
@@ -309,11 +340,9 @@ export const userOrgs = pgTable("userOrgs", {
onDelete: "cascade"
})
.notNull(),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId),
isOwner: boolean("isOwner").notNull().default(false),
autoProvisioned: boolean("autoProvisioned").default(false)
autoProvisioned: boolean("autoProvisioned").default(false),
pamUsername: varchar("pamUsername") // cleaned username for ssh and such
});
export const emailVerificationCodes = pgTable("emailVerificationCodes", {
@@ -351,9 +380,30 @@ export const roles = pgTable("roles", {
.notNull(),
isAdmin: boolean("isAdmin"),
name: varchar("name").notNull(),
description: varchar("description")
description: varchar("description"),
requireDeviceApproval: boolean("requireDeviceApproval").default(false),
sshSudoMode: varchar("sshSudoMode", { length: 32 }).default("none"), // "none" | "full" | "commands"
sshSudoCommands: text("sshSudoCommands").default("[]"),
sshCreateHomeDir: boolean("sshCreateHomeDir").default(true),
sshUnixGroups: text("sshUnixGroups").default("[]")
});
export const userOrgRoles = pgTable(
"userOrgRoles",
{
userId: varchar("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
orgId: varchar("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
},
(t) => [unique().on(t.userId, t.orgId, t.roleId)]
);
export const roleActions = pgTable("roleActions", {
roleId: integer("roleId")
.notNull()
@@ -421,12 +471,22 @@ export const userInvites = pgTable("userInvites", {
.references(() => orgs.orgId, { onDelete: "cascade" }),
email: varchar("email").notNull(),
expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
tokenHash: varchar("token").notNull(),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
tokenHash: varchar("token").notNull()
});
export const userInviteRoles = pgTable(
"userInviteRoles",
{
inviteId: varchar("inviteId")
.notNull()
.references(() => userInvites.inviteId, { onDelete: "cascade" }),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
},
(t) => [primaryKey({ columns: [t.inviteId, t.roleId] })]
);
export const resourcePincode = pgTable("resourcePincode", {
pincodeId: serial("pincodeId").primaryKey(),
resourceId: integer("resourceId")
@@ -452,6 +512,23 @@ export const resourceHeaderAuth = pgTable("resourceHeaderAuth", {
headerAuthHash: varchar("headerAuthHash").notNull()
});
export const resourceHeaderAuthExtendedCompatibility = pgTable(
"resourceHeaderAuthExtendedCompatibility",
{
headerAuthExtendedCompatibilityId: serial(
"headerAuthExtendedCompatibilityId"
).primaryKey(),
resourceId: integer("resourceId")
.notNull()
.references(() => resources.resourceId, { onDelete: "cascade" }),
extendedCompatibilityIsActivated: boolean(
"extendedCompatibilityIsActivated"
)
.notNull()
.default(true)
}
);
export const resourceAccessToken = pgTable("resourceAccessToken", {
accessTokenId: varchar("accessTokenId").primaryKey(),
orgId: varchar("orgId")
@@ -560,7 +637,8 @@ export const idp = pgTable("idp", {
type: varchar("type").notNull(),
defaultRoleMapping: varchar("defaultRoleMapping"),
defaultOrgMapping: varchar("defaultOrgMapping"),
autoProvision: boolean("autoProvision").notNull().default(false)
autoProvision: boolean("autoProvision").notNull().default(false),
tags: text("tags")
});
export const idpOidcConfig = pgTable("idpOidcConfig", {
@@ -657,7 +735,12 @@ export const clients = pgTable("clients", {
online: boolean("online").notNull().default(false),
// endpoint: varchar("endpoint"),
lastHolePunch: integer("lastHolePunch"),
maxConnections: integer("maxConnections")
maxConnections: integer("maxConnections"),
archived: boolean("archived").notNull().default(false),
blocked: boolean("blocked").notNull().default(false),
approvalState: varchar("approvalState").$type<
"pending" | "approved" | "denied"
>()
});
export const clientSitesAssociationsCache = pgTable(
@@ -667,6 +750,7 @@ export const clientSitesAssociationsCache = pgTable(
.notNull(),
siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false),
isJitMode: boolean("isJitMode").notNull().default(false),
endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}
@@ -681,6 +765,16 @@ export const clientSiteResourcesAssociationsCache = pgTable(
}
);
export const clientPostureSnapshots = pgTable("clientPostureSnapshots", {
snapshotId: serial("snapshotId").primaryKey(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}),
collectedAt: integer("collectedAt").notNull()
});
export const olms = pgTable("olms", {
olmId: varchar("id").primaryKey(),
secretHash: varchar("secretHash").notNull(),
@@ -695,7 +789,118 @@ export const olms = pgTable("olms", {
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
}),
archived: boolean("archived").notNull().default(false)
});
export const currentFingerprint = pgTable("currentFingerprint", {
fingerprintId: serial("id").primaryKey(),
olmId: text("olmId")
.references(() => olms.olmId, { onDelete: "cascade" })
.notNull(),
firstSeen: integer("firstSeen").notNull(),
lastSeen: integer("lastSeen").notNull(),
lastCollectedAt: integer("lastCollectedAt").notNull(),
username: text("username"),
hostname: text("hostname"),
platform: text("platform"),
osVersion: text("osVersion"),
kernelVersion: text("kernelVersion"),
arch: text("arch"),
deviceModel: text("deviceModel"),
serialNumber: text("serialNumber"),
platformFingerprint: varchar("platformFingerprint"),
// Platform-agnostic checks
biometricsEnabled: boolean("biometricsEnabled").notNull().default(false),
diskEncrypted: boolean("diskEncrypted").notNull().default(false),
firewallEnabled: boolean("firewallEnabled").notNull().default(false),
autoUpdatesEnabled: boolean("autoUpdatesEnabled").notNull().default(false),
tpmAvailable: boolean("tpmAvailable").notNull().default(false),
// Windows-specific posture check information
windowsAntivirusEnabled: boolean("windowsAntivirusEnabled")
.notNull()
.default(false),
// macOS-specific posture check information
macosSipEnabled: boolean("macosSipEnabled").notNull().default(false),
macosGatekeeperEnabled: boolean("macosGatekeeperEnabled")
.notNull()
.default(false),
macosFirewallStealthMode: boolean("macosFirewallStealthMode")
.notNull()
.default(false),
// Linux-specific posture check information
linuxAppArmorEnabled: boolean("linuxAppArmorEnabled")
.notNull()
.default(false),
linuxSELinuxEnabled: boolean("linuxSELinuxEnabled").notNull().default(false)
});
export const fingerprintSnapshots = pgTable("fingerprintSnapshots", {
snapshotId: serial("id").primaryKey(),
fingerprintId: integer("fingerprintId").references(
() => currentFingerprint.fingerprintId,
{
onDelete: "set null"
}
),
username: text("username"),
hostname: text("hostname"),
platform: text("platform"),
osVersion: text("osVersion"),
kernelVersion: text("kernelVersion"),
arch: text("arch"),
deviceModel: text("deviceModel"),
serialNumber: text("serialNumber"),
platformFingerprint: varchar("platformFingerprint"),
// Platform-agnostic checks
biometricsEnabled: boolean("biometricsEnabled").notNull().default(false),
diskEncrypted: boolean("diskEncrypted").notNull().default(false),
firewallEnabled: boolean("firewallEnabled").notNull().default(false),
autoUpdatesEnabled: boolean("autoUpdatesEnabled").notNull().default(false),
tpmAvailable: boolean("tpmAvailable").notNull().default(false),
// Windows-specific posture check information
windowsAntivirusEnabled: boolean("windowsAntivirusEnabled")
.notNull()
.default(false),
// macOS-specific posture check information
macosSipEnabled: boolean("macosSipEnabled").notNull().default(false),
macosGatekeeperEnabled: boolean("macosGatekeeperEnabled")
.notNull()
.default(false),
macosFirewallStealthMode: boolean("macosFirewallStealthMode")
.notNull()
.default(false),
// Linux-specific posture check information
linuxAppArmorEnabled: boolean("linuxAppArmorEnabled")
.notNull()
.default(false),
linuxSELinuxEnabled: boolean("linuxSELinuxEnabled")
.notNull()
.default(false),
hash: text("hash").notNull(),
collectedAt: integer("collectedAt").notNull()
});
export const olmSessions = pgTable("clientSession", {
@@ -824,6 +1029,16 @@ export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", {
})
});
export const roundTripMessageTracker = pgTable("roundTripMessageTracker", {
messageId: serial("messageId").primaryKey(),
wsClientId: varchar("clientId"),
messageType: varchar("messageType"),
sentAt: bigint("sentAt", { mode: "number" }).notNull(),
receivedAt: bigint("receivedAt", { mode: "number" }),
error: text("error"),
complete: boolean("complete").notNull().default(false)
});
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -847,11 +1062,16 @@ export type UserSite = InferSelectModel<typeof userSites>;
export type RoleResource = InferSelectModel<typeof roleResources>;
export type UserResource = InferSelectModel<typeof userResources>;
export type UserInvite = InferSelectModel<typeof userInvites>;
export type UserInviteRole = InferSelectModel<typeof userInviteRoles>;
export type UserOrg = InferSelectModel<typeof userOrgs>;
export type UserOrgRole = InferSelectModel<typeof userOrgRoles>;
export type ResourceSession = InferSelectModel<typeof resourceSessions>;
export type ResourcePincode = InferSelectModel<typeof resourcePincode>;
export type ResourcePassword = InferSelectModel<typeof resourcePassword>;
export type ResourceHeaderAuth = InferSelectModel<typeof resourceHeaderAuth>;
export type ResourceHeaderAuthExtendedCompatibility = InferSelectModel<
typeof resourceHeaderAuthExtendedCompatibility
>;
export type ResourceOtp = InferSelectModel<typeof resourceOtp>;
export type ResourceAccessToken = InferSelectModel<typeof resourceAccessToken>;
export type ResourceWhitelist = InferSelectModel<typeof resourceWhitelist>;
@@ -881,3 +1101,6 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker
>;

View File

@@ -1,4 +1,12 @@
import { db, loginPage, LoginPage, loginPageOrg, Org, orgs } from "@server/db";
import {
db,
loginPage,
LoginPage,
loginPageOrg,
Org,
orgs,
roles
} from "@server/db";
import {
Resource,
ResourcePassword,
@@ -12,17 +20,19 @@ import {
resources,
roleResources,
sessions,
userOrgs,
userResources,
users
users,
ResourceHeaderAuthExtendedCompatibility,
resourceHeaderAuthExtendedCompatibility
} from "@server/db";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
export type ResourceWithAuth = {
resource: Resource | null;
pincode: ResourcePincode | null;
password: ResourcePassword | null;
headerAuth: ResourceHeaderAuth | null;
headerAuthExtendedCompatibility: ResourceHeaderAuthExtendedCompatibility | null;
org: Org;
};
@@ -52,6 +62,13 @@ export async function getResourceByDomain(
resourceHeaderAuth,
eq(resourceHeaderAuth.resourceId, resources.resourceId)
)
.leftJoin(
resourceHeaderAuthExtendedCompatibility,
eq(
resourceHeaderAuthExtendedCompatibility.resourceId,
resources.resourceId
)
)
.innerJoin(orgs, eq(orgs.orgId, resources.orgId))
.where(eq(resources.fullDomain, domain))
.limit(1);
@@ -65,6 +82,8 @@ export async function getResourceByDomain(
pincode: result.resourcePincode,
password: result.resourcePassword,
headerAuth: result.resourceHeaderAuth,
headerAuthExtendedCompatibility:
result.resourceHeaderAuthExtendedCompatibility,
org: result.orgs
};
}
@@ -92,16 +111,15 @@ export async function getUserSessionWithUser(
}
/**
* Get user organization role
* Get role name by role ID (for display).
*/
export async function getUserOrgRole(userId: string, orgId: string) {
const userOrgRole = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
export async function getRoleName(roleId: number): Promise<string | null> {
const [row] = await db
.select({ name: roles.name })
.from(roles)
.where(eq(roles.roleId, roleId))
.limit(1);
return userOrgRole.length > 0 ? userOrgRole[0] : null;
return row?.name ?? null;
}
/**
@@ -109,7 +127,7 @@ export async function getUserOrgRole(userId: string, orgId: string) {
*/
export async function getRoleResourceAccess(
resourceId: number,
roleId: number
roleIds: number[]
) {
const roleResourceAccess = await db
.select()
@@ -117,12 +135,11 @@ export async function getRoleResourceAccess(
.where(
and(
eq(roleResources.resourceId, resourceId),
eq(roleResources.roleId, roleId)
inArray(roleResources.roleId, roleIds)
)
)
.limit(1);
);
return roleResourceAccess.length > 0 ? roleResourceAccess[0] : null;
return roleResourceAccess.length > 0 ? roleResourceAccess : null;
}
/**

View File

@@ -20,6 +20,7 @@ function createDb() {
export const db = createDb();
export default db;
export const primaryDb = db;
export type Transaction = Parameters<
Parameters<(typeof db)["transaction"]>[0]
>[0];

View File

@@ -1,3 +1,6 @@
export * from "./driver";
export * from "./logsDriver";
export * from "./safeRead";
export * from "./schema/schema";
export * from "./schema/privateSchema";
export * from "./migrate";

View File

@@ -0,0 +1,7 @@
import { db as mainDb } from "./driver";
// SQLite doesn't support separate databases for logs in the same way as Postgres
// Always use the main database connection for SQLite
export const logsDb = mainDb;
export default logsDb;
export const primaryLogsDb = logsDb;

View File

@@ -4,7 +4,7 @@ import path from "path";
const migrationsFolder = path.join("server/migrations");
const runMigrations = async () => {
export const runMigrations = async () => {
console.log("Running migrations...");
try {
migrate(db as any, {
@@ -16,5 +16,3 @@ const runMigrations = async () => {
process.exit(1);
}
};
runMigrations();

View File

@@ -0,0 +1,11 @@
import { db } from "./driver";
/**
* Runs a read query. For SQLite there is no replica/primary distinction,
* so the query is executed once against the database.
*/
export async function safeRead<T>(
query: (d: typeof db) => Promise<T>
): Promise<T> {
return query(db);
}

View File

@@ -1,13 +1,13 @@
import {
sqliteTable,
integer,
text,
real,
index
} from "drizzle-orm/sqlite-core";
import { InferSelectModel } from "drizzle-orm";
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";
import { metadata } from "@app/app/[orgId]/settings/layout";
import {
index,
integer,
primaryKey,
real,
sqliteTable,
text
} from "drizzle-orm/sqlite-core";
import { clients, domains, exitNodes, orgs, sessions, siteResources, sites, users } from "./schema";
export const certificates = sqliteTable("certificates", {
certId: integer("certId").primaryKey({ autoIncrement: true }),
@@ -71,13 +71,16 @@ export const subscriptions = sqliteTable("subscriptions", {
canceledAt: integer("canceledAt"),
createdAt: integer("createdAt").notNull(),
updatedAt: integer("updatedAt"),
billingCycleAnchor: integer("billingCycleAnchor")
version: integer("version"),
billingCycleAnchor: integer("billingCycleAnchor"),
type: text("type") // tier1, tier2, tier3, or license
});
export const subscriptionItems = sqliteTable("subscriptionItems", {
subscriptionItemId: integer("subscriptionItemId").primaryKey({
autoIncrement: true
}),
stripeSubscriptionItemId: text("stripeSubscriptionItemId"),
subscriptionId: text("subscriptionId")
.notNull()
.references(() => subscriptions.subscriptionId, {
@@ -85,6 +88,7 @@ export const subscriptionItems = sqliteTable("subscriptionItems", {
}),
planId: text("planId").notNull(),
priceId: text("priceId"),
featureId: text("featureId"),
meterId: text("meterId"),
unitAmount: real("unitAmount"),
tiers: text("tiers"),
@@ -127,6 +131,7 @@ export const limits = sqliteTable("limits", {
})
.notNull(),
value: real("value"),
override: integer("override", { mode: "boolean" }).default(false),
description: text("description")
});
@@ -203,6 +208,31 @@ export const loginPageOrg = sqliteTable("loginPageOrg", {
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const loginPageBranding = sqliteTable("loginPageBranding", {
loginPageBrandingId: integer("loginPageBrandingId").primaryKey({
autoIncrement: true
}),
logoUrl: text("logoUrl"),
logoWidth: integer("logoWidth").notNull(),
logoHeight: integer("logoHeight").notNull(),
primaryColor: text("primaryColor"),
resourceTitle: text("resourceTitle").notNull(),
resourceSubtitle: text("resourceSubtitle"),
orgTitle: text("orgTitle"),
orgSubtitle: text("orgSubtitle")
});
export const loginPageBrandingOrg = sqliteTable("loginPageBrandingOrg", {
loginPageBrandingId: integer("loginPageBrandingId")
.notNull()
.references(() => loginPageBranding.loginPageBrandingId, {
onDelete: "cascade"
}),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const sessionTransferToken = sqliteTable("sessionTransferToken", {
token: text("token").primaryKey(),
sessionId: text("sessionId")
@@ -249,6 +279,7 @@ export const accessAuditLog = sqliteTable(
actor: text("actor"),
actorId: text("actorId"),
resourceId: integer("resourceId"),
siteResourceId: integer("siteResourceId"),
ip: text("ip"),
location: text("location"),
type: text("type").notNull(),
@@ -265,6 +296,109 @@ export const accessAuditLog = sqliteTable(
]
);
export const connectionAuditLog = sqliteTable(
"connectionAuditLog",
{
id: integer("id").primaryKey({ autoIncrement: true }),
sessionId: text("sessionId").notNull(),
siteResourceId: integer("siteResourceId").references(
() => siteResources.siteResourceId,
{ onDelete: "cascade" }
),
orgId: text("orgId").references(() => orgs.orgId, {
onDelete: "cascade"
}),
siteId: integer("siteId").references(() => sites.siteId, {
onDelete: "cascade"
}),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}),
userId: text("userId").references(() => users.userId, {
onDelete: "cascade"
}),
sourceAddr: text("sourceAddr").notNull(),
destAddr: text("destAddr").notNull(),
protocol: text("protocol").notNull(),
startedAt: integer("startedAt").notNull(),
endedAt: integer("endedAt"),
bytesTx: integer("bytesTx"),
bytesRx: integer("bytesRx")
},
(table) => [
index("idx_accessAuditLog_startedAt").on(table.startedAt),
index("idx_accessAuditLog_org_startedAt").on(
table.orgId,
table.startedAt
),
index("idx_accessAuditLog_siteResourceId").on(table.siteResourceId)
]
);
export const approvals = sqliteTable("approvals", {
approvalId: integer("approvalId").primaryKey({ autoIncrement: true }),
timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
}), // olms reference user devices clients
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
decision: text("decision")
.$type<"approved" | "denied" | "pending">()
.default("pending")
.notNull(),
type: text("type")
.$type<"user_device" /*| 'proxy' // for later */>()
.notNull()
});
export const bannedEmails = sqliteTable("bannedEmails", {
email: text("email").primaryKey()
});
export const bannedIps = sqliteTable("bannedIps", {
ip: text("ip").primaryKey()
});
export const siteProvisioningKeys = sqliteTable("siteProvisioningKeys", {
siteProvisioningKeyId: text("siteProvisioningKeyId").primaryKey(),
name: text("name").notNull(),
siteProvisioningKeyHash: text("siteProvisioningKeyHash").notNull(),
lastChars: text("lastChars").notNull(),
createdAt: text("dateCreated").notNull(),
lastUsed: text("lastUsed"),
maxBatchSize: integer("maxBatchSize"), // null = no limit
numUsed: integer("numUsed").notNull().default(0),
validUntil: text("validUntil")
});
export const siteProvisioningKeyOrg = sqliteTable(
"siteProvisioningKeyOrg",
{
siteProvisioningKeyId: text("siteProvisioningKeyId")
.notNull()
.references(() => siteProvisioningKeys.siteProvisioningKeyId, {
onDelete: "cascade"
}),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
},
(table) => [
primaryKey({
columns: [table.siteProvisioningKeyId, table.orgId]
})
]
);
export type Approval = InferSelectModel<typeof approvals>;
export type Limit = InferSelectModel<typeof limits>;
export type Account = InferSelectModel<typeof account>;
export type Certificate = InferSelectModel<typeof certificates>;
@@ -282,5 +416,7 @@ export type RemoteExitNodeSession = InferSelectModel<
>;
export type ExitNodeOrg = InferSelectModel<typeof exitNodeOrgs>;
export type LoginPage = InferSelectModel<typeof loginPage>;
export type LoginPageBranding = InferSelectModel<typeof loginPageBranding>;
export type ActionAuditLog = InferSelectModel<typeof actionAuditLog>;
export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>;
export type ConnectionAuditLog = InferSelectModel<typeof connectionAuditLog>;

View File

@@ -1,7 +1,13 @@
import { randomUUID } from "crypto";
import { InferSelectModel } from "drizzle-orm";
import { sqliteTable, text, integer, index } from "drizzle-orm/sqlite-core";
import { no } from "zod/v4/locales";
import {
index,
integer,
primaryKey,
sqliteTable,
text,
unique
} from "drizzle-orm/sqlite-core";
export const domains = sqliteTable("domains", {
domainId: text("domainId").primaryKey(),
@@ -14,7 +20,8 @@ export const domains = sqliteTable("domains", {
failed: integer("failed", { mode: "boolean" }).notNull().default(false),
tries: integer("tries").notNull().default(0),
certResolver: text("certResolver"),
preferWildcardCert: integer("preferWildcardCert", { mode: "boolean" })
preferWildcardCert: integer("preferWildcardCert", { mode: "boolean" }),
errorMessage: text("errorMessage")
});
export const dnsRecords = sqliteTable("dnsRecords", {
@@ -46,7 +53,14 @@ export const orgs = sqliteTable("orgs", {
.default(0),
settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year
.notNull()
.default(0)
.default(0),
settingsLogRetentionDaysConnection: integer("settingsLogRetentionDaysConnection") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year
.notNull()
.default(0),
sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format)
sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format)
isBillingOrg: integer("isBillingOrg", { mode: "boolean" }),
billingOrgId: text("billingOrgId")
});
export const userDomains = sqliteTable("userDomains", {
@@ -86,6 +100,7 @@ export const sites = sqliteTable("sites", {
lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false),
lastPing: integer("lastPing"),
// exit node stuff that is how to connect to the site when it has a wg server
address: text("address"), // this is the address of the wireguard interface in newt
@@ -144,7 +159,20 @@ export const resources = sqliteTable("resources", {
proxyProtocol: integer("proxyProtocol", { mode: "boolean" })
.notNull()
.default(false),
proxyProtocolVersion: integer("proxyProtocolVersion").default(1)
proxyProtocolVersion: integer("proxyProtocolVersion").default(1),
maintenanceModeEnabled: integer("maintenanceModeEnabled", {
mode: "boolean"
})
.notNull()
.default(false),
maintenanceModeType: text("maintenanceModeType", {
enum: ["forced", "automatic"]
}).default("forced"), // "forced" = always show, "automatic" = only when down
maintenanceTitle: text("maintenanceTitle"),
maintenanceMessage: text("maintenanceMessage"),
maintenanceEstimatedTime: text("maintenanceEstimatedTime"),
postAuthPath: text("postAuthPath")
});
export const targets = sqliteTable("targets", {
@@ -195,7 +223,9 @@ export const targetHealthCheck = sqliteTable("targetHealthCheck", {
}).default(true),
hcMethod: text("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy"
hcHealth: text("hcHealth")
.$type<"unknown" | "healthy" | "unhealthy">()
.default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName")
});
@@ -227,14 +257,23 @@ export const siteResources = sqliteTable("siteResources", {
.references(() => orgs.orgId, { onDelete: "cascade" }),
niceId: text("niceId").notNull(),
name: text("name").notNull(),
mode: text("mode").notNull(), // "host" | "cidr" | "port"
mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
protocol: text("protocol"), // only for port mode
proxyPort: integer("proxyPort"), // only for port mode
destinationPort: integer("destinationPort"), // only for port mode
destination: text("destination").notNull(), // ip, cidr, hostname
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true),
alias: text("alias"),
aliasAddress: text("aliasAddress")
aliasAddress: text("aliasAddress"),
tcpPortRangeString: text("tcpPortRangeString").notNull().default("*"),
udpPortRangeString: text("udpPortRangeString").notNull().default("*"),
disableIcmp: integer("disableIcmp", { mode: "boolean" })
.notNull()
.default(false),
authDaemonPort: integer("authDaemonPort").default(22123),
authDaemonMode: text("authDaemonMode")
.$type<"site" | "remote">()
.default("site")
});
export const clientSiteResources = sqliteTable("clientSiteResources", {
@@ -287,6 +326,9 @@ export const users = sqliteTable("user", {
dateCreated: text("dateCreated").notNull(),
termsAcceptedTimestamp: text("termsAcceptedTimestamp"),
termsVersion: text("termsVersion"),
marketingEmailConsent: integer("marketingEmailConsent", {
mode: "boolean"
}).default(false),
serverAdmin: integer("serverAdmin", { mode: "boolean" })
.notNull()
.default(false),
@@ -362,7 +404,12 @@ export const clients = sqliteTable("clients", {
type: text("type").notNull(), // "olm"
online: integer("online", { mode: "boolean" }).notNull().default(false),
// endpoint: text("endpoint"),
lastHolePunch: integer("lastHolePunch")
lastHolePunch: integer("lastHolePunch"),
archived: integer("archived", { mode: "boolean" }).notNull().default(false),
blocked: integer("blocked", { mode: "boolean" }).notNull().default(false),
approvalState: text("approvalState").$type<
"pending" | "approved" | "denied"
>()
});
export const clientSitesAssociationsCache = sqliteTable(
@@ -374,6 +421,9 @@ export const clientSitesAssociationsCache = sqliteTable(
isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull()
.default(false),
isJitMode: integer("isJitMode", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}
@@ -402,7 +452,160 @@ export const olms = sqliteTable("olms", {
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
archived: integer("archived", { mode: "boolean" }).notNull().default(false)
});
export const currentFingerprint = sqliteTable("currentFingerprint", {
fingerprintId: integer("id").primaryKey({ autoIncrement: true }),
olmId: text("olmId")
.references(() => olms.olmId, { onDelete: "cascade" })
.notNull(),
firstSeen: integer("firstSeen").notNull(),
lastSeen: integer("lastSeen").notNull(),
lastCollectedAt: integer("lastCollectedAt").notNull(),
username: text("username"),
hostname: text("hostname"),
platform: text("platform"),
osVersion: text("osVersion"),
kernelVersion: text("kernelVersion"),
arch: text("arch"),
deviceModel: text("deviceModel"),
serialNumber: text("serialNumber"),
platformFingerprint: text("platformFingerprint"),
// Platform-agnostic checks
biometricsEnabled: integer("biometricsEnabled", { mode: "boolean" })
.notNull()
.default(false),
diskEncrypted: integer("diskEncrypted", { mode: "boolean" })
.notNull()
.default(false),
firewallEnabled: integer("firewallEnabled", { mode: "boolean" })
.notNull()
.default(false),
autoUpdatesEnabled: integer("autoUpdatesEnabled", { mode: "boolean" })
.notNull()
.default(false),
tpmAvailable: integer("tpmAvailable", { mode: "boolean" })
.notNull()
.default(false),
// Windows-specific posture check information
windowsAntivirusEnabled: integer("windowsAntivirusEnabled", {
mode: "boolean"
})
.notNull()
.default(false),
// macOS-specific posture check information
macosSipEnabled: integer("macosSipEnabled", { mode: "boolean" })
.notNull()
.default(false),
macosGatekeeperEnabled: integer("macosGatekeeperEnabled", {
mode: "boolean"
})
.notNull()
.default(false),
macosFirewallStealthMode: integer("macosFirewallStealthMode", {
mode: "boolean"
})
.notNull()
.default(false),
// Linux-specific posture check information
linuxAppArmorEnabled: integer("linuxAppArmorEnabled", { mode: "boolean" })
.notNull()
.default(false),
linuxSELinuxEnabled: integer("linuxSELinuxEnabled", {
mode: "boolean"
})
.notNull()
.default(false)
});
export const fingerprintSnapshots = sqliteTable("fingerprintSnapshots", {
snapshotId: integer("id").primaryKey({ autoIncrement: true }),
fingerprintId: integer("fingerprintId").references(
() => currentFingerprint.fingerprintId,
{
onDelete: "set null"
}
),
username: text("username"),
hostname: text("hostname"),
platform: text("platform"),
osVersion: text("osVersion"),
kernelVersion: text("kernelVersion"),
arch: text("arch"),
deviceModel: text("deviceModel"),
serialNumber: text("serialNumber"),
platformFingerprint: text("platformFingerprint"),
// Platform-agnostic checks
biometricsEnabled: integer("biometricsEnabled", { mode: "boolean" })
.notNull()
.default(false),
diskEncrypted: integer("diskEncrypted", { mode: "boolean" })
.notNull()
.default(false),
firewallEnabled: integer("firewallEnabled", { mode: "boolean" })
.notNull()
.default(false),
autoUpdatesEnabled: integer("autoUpdatesEnabled", { mode: "boolean" })
.notNull()
.default(false),
tpmAvailable: integer("tpmAvailable", { mode: "boolean" })
.notNull()
.default(false),
// Windows-specific posture check information
windowsAntivirusEnabled: integer("windowsAntivirusEnabled", {
mode: "boolean"
})
.notNull()
.default(false),
// macOS-specific posture check information
macosSipEnabled: integer("macosSipEnabled", { mode: "boolean" })
.notNull()
.default(false),
macosGatekeeperEnabled: integer("macosGatekeeperEnabled", {
mode: "boolean"
})
.notNull()
.default(false),
macosFirewallStealthMode: integer("macosFirewallStealthMode", {
mode: "boolean"
})
.notNull()
.default(false),
// Linux-specific posture check information
linuxAppArmorEnabled: integer("linuxAppArmorEnabled", { mode: "boolean" })
.notNull()
.default(false),
linuxSELinuxEnabled: integer("linuxSELinuxEnabled", {
mode: "boolean"
})
.notNull()
.default(false),
hash: text("hash").notNull(),
collectedAt: integer("collectedAt").notNull()
});
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
@@ -450,13 +653,11 @@ export const userOrgs = sqliteTable("userOrgs", {
onDelete: "cascade"
})
.notNull(),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId),
isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false),
autoProvisioned: integer("autoProvisioned", {
mode: "boolean"
}).default(false)
}).default(false),
pamUsername: text("pamUsername") // cleaned username for ssh and such
});
export const emailVerificationCodes = sqliteTable("emailVerificationCodes", {
@@ -494,9 +695,34 @@ export const roles = sqliteTable("roles", {
.notNull(),
isAdmin: integer("isAdmin", { mode: "boolean" }),
name: text("name").notNull(),
description: text("description")
description: text("description"),
requireDeviceApproval: integer("requireDeviceApproval", {
mode: "boolean"
}).default(false),
sshSudoMode: text("sshSudoMode").default("none"), // "none" | "full" | "commands"
sshSudoCommands: text("sshSudoCommands").default("[]"),
sshCreateHomeDir: integer("sshCreateHomeDir", { mode: "boolean" }).default(
true
),
sshUnixGroups: text("sshUnixGroups").default("[]")
});
export const userOrgRoles = sqliteTable(
"userOrgRoles",
{
userId: text("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
},
(t) => [unique().on(t.userId, t.orgId, t.roleId)]
);
export const roleActions = sqliteTable("roleActions", {
roleId: integer("roleId")
.notNull()
@@ -582,12 +808,22 @@ export const userInvites = sqliteTable("userInvites", {
.references(() => orgs.orgId, { onDelete: "cascade" }),
email: text("email").notNull(),
expiresAt: integer("expiresAt").notNull(),
tokenHash: text("token").notNull(),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
tokenHash: text("token").notNull()
});
export const userInviteRoles = sqliteTable(
"userInviteRoles",
{
inviteId: text("inviteId")
.notNull()
.references(() => userInvites.inviteId, { onDelete: "cascade" }),
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" })
},
(t) => [primaryKey({ columns: [t.inviteId, t.roleId] })]
);
export const resourcePincode = sqliteTable("resourcePincode", {
pincodeId: integer("pincodeId").primaryKey({
autoIncrement: true
@@ -619,6 +855,26 @@ export const resourceHeaderAuth = sqliteTable("resourceHeaderAuth", {
headerAuthHash: text("headerAuthHash").notNull()
});
export const resourceHeaderAuthExtendedCompatibility = sqliteTable(
"resourceHeaderAuthExtendedCompatibility",
{
headerAuthExtendedCompatibilityId: integer(
"headerAuthExtendedCompatibilityId"
).primaryKey({
autoIncrement: true
}),
resourceId: integer("resourceId")
.notNull()
.references(() => resources.resourceId, { onDelete: "cascade" }),
extendedCompatibilityIsActivated: integer(
"extendedCompatibilityIsActivated",
{ mode: "boolean" }
)
.notNull()
.default(true)
}
);
export const resourceAccessToken = sqliteTable("resourceAccessToken", {
accessTokenId: text("accessTokenId").primaryKey(),
orgId: text("orgId")
@@ -733,7 +989,8 @@ export const idp = sqliteTable("idp", {
mode: "boolean"
})
.notNull()
.default(false)
.default(false),
tags: text("tags")
});
// Identity Provider OAuth Configuration
@@ -874,6 +1131,16 @@ export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", {
})
});
export const roundTripMessageTracker = sqliteTable("roundTripMessageTracker", {
messageId: integer("messageId").primaryKey({ autoIncrement: true }),
wsClientId: text("clientId"),
messageType: text("messageType"),
sentAt: integer("sentAt").notNull(),
receivedAt: integer("receivedAt"),
error: text("error"),
complete: integer("complete", { mode: "boolean" }).notNull().default(false)
});
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -899,11 +1166,16 @@ export type UserSite = InferSelectModel<typeof userSites>;
export type RoleResource = InferSelectModel<typeof roleResources>;
export type UserResource = InferSelectModel<typeof userResources>;
export type UserInvite = InferSelectModel<typeof userInvites>;
export type UserInviteRole = InferSelectModel<typeof userInviteRoles>;
export type UserOrg = InferSelectModel<typeof userOrgs>;
export type UserOrgRole = InferSelectModel<typeof userOrgRoles>;
export type ResourceSession = InferSelectModel<typeof resourceSessions>;
export type ResourcePincode = InferSelectModel<typeof resourcePincode>;
export type ResourcePassword = InferSelectModel<typeof resourcePassword>;
export type ResourceHeaderAuth = InferSelectModel<typeof resourceHeaderAuth>;
export type ResourceHeaderAuthExtendedCompatibility = InferSelectModel<
typeof resourceHeaderAuthExtendedCompatibility
>;
export type ResourceOtp = InferSelectModel<typeof resourceOtp>;
export type ResourceAccessToken = InferSelectModel<typeof resourceAccessToken>;
export type ResourceWhitelist = InferSelectModel<typeof resourceWhitelist>;
@@ -932,3 +1204,6 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker
>;

View File

@@ -10,6 +10,7 @@ export async function sendEmail(
from: string | undefined;
to: string | undefined;
subject: string;
replyTo?: string;
}
) {
if (!emailClient) {
@@ -32,6 +33,7 @@ export async function sendEmail(
address: opts.from
},
to: opts.to,
replyTo: opts.replyTo,
subject: opts.subject,
html: emailHtml
});

View File

@@ -0,0 +1,118 @@
import React from "react";
import { Body, Head, Html, Preview, Tailwind } from "@react-email/components";
import { themeColors } from "./lib/theme";
import {
EmailContainer,
EmailFooter,
EmailGreeting,
EmailHeading,
EmailInfoSection,
EmailLetterHead,
EmailSection,
EmailSignature,
EmailText
} from "./components/Email";
import CopyCodeBox from "./components/CopyCodeBox";
import ButtonLink from "./components/ButtonLink";
type EnterpriseEditionKeyGeneratedProps = {
keyValue: string;
personalUseOnly: boolean;
users: number;
sites: number;
modifySubscriptionLink?: string;
};
export const EnterpriseEditionKeyGenerated = ({
keyValue,
personalUseOnly,
users,
sites,
modifySubscriptionLink
}: EnterpriseEditionKeyGeneratedProps) => {
const previewText = personalUseOnly
? "Your Enterprise Edition key for personal use is ready"
: "Thank you for your purchase — your Enterprise Edition key is ready";
return (
<Html>
<Head />
<Preview>{previewText}</Preview>
<Tailwind config={themeColors}>
<Body className="font-sans bg-gray-50">
<EmailContainer>
<EmailLetterHead />
<EmailGreeting>Hi there,</EmailGreeting>
{personalUseOnly ? (
<EmailText>
Your Enterprise Edition license key has been
generated. Qualifying users can use the
Enterprise Edition for free for{" "}
<strong>personal use only</strong>.
</EmailText>
) : (
<>
<EmailText>
Thank you for your purchase. Your Enterprise
Edition license key is ready. Below are the
terms of your license.
</EmailText>
<EmailInfoSection
title="License details"
items={[
{
label: "Licensed users",
value: users
},
{
label: "Licensed sites",
value: sites
}
]}
/>
{modifySubscriptionLink && (
<EmailSection>
<ButtonLink
href={modifySubscriptionLink}
>
Modify subscription
</ButtonLink>
</EmailSection>
)}
</>
)}
<EmailSection>
<EmailText>Your license key:</EmailText>
<CopyCodeBox
text={keyValue}
hint="Copy this key and use it when activating Enterprise Edition on your Pangolin host."
/>
</EmailSection>
<EmailText>
If you need to purchase additional license keys or
modify your existing license, please reach out to
our support team at{" "}
<a
href="mailto:support@pangolin.net"
className="text-primary font-medium"
>
support@pangolin.net
</a>
.
</EmailText>
<EmailFooter>
<EmailSignature />
</EmailFooter>
</EmailContainer>
</Body>
</Tailwind>
</Html>
);
};
export default EnterpriseEditionKeyGenerated;

View File

@@ -1,6 +1,14 @@
import React from "react";
export default function CopyCodeBox({ text }: { text: string }) {
const DEFAULT_HINT = "Copy and paste this code when prompted";
export default function CopyCodeBox({
text,
hint
}: {
text: string;
hint?: string;
}) {
return (
<div className="inline-block">
<div className="bg-gray-50 border border-gray-200 rounded-lg px-6 py-4 mx-auto">
@@ -8,9 +16,7 @@ export default function CopyCodeBox({ text }: { text: string }) {
{text}
</span>
</div>
<p className="text-xs text-gray-500 mt-2">
Copy and paste this code when prompted
</p>
<p className="text-xs text-gray-500 mt-2">{hint ?? DEFAULT_HINT}</p>
</div>
);
}

View File

@@ -74,7 +74,7 @@ declare global {
session: Session;
userOrg?: UserOrg;
apiKeyOrg?: ApiKeyOrg;
userOrgRoleId?: number;
userOrgRoleIds?: number[];
userOrgId?: string;
userOrgIds?: string[];
remoteExitNode?: RemoteExitNode;

View File

@@ -17,6 +17,7 @@ import fs from "fs";
import path from "path";
import { APP_PATH } from "./lib/consts";
import yaml from "js-yaml";
import { z } from "zod";
const dev = process.env.ENVIRONMENT !== "prod";
const externalPort = config.getRawConfig().server.integration_port;
@@ -38,12 +39,24 @@ export function createIntegrationApiServer() {
apiServer.use(cookieParser());
apiServer.use(express.json());
const openApiDocumentation = getOpenApiDocumentation();
apiServer.use(
"/v1/docs",
swaggerUi.serve,
swaggerUi.setup(getOpenApiDocumentation())
swaggerUi.setup(openApiDocumentation)
);
// Unauthenticated OpenAPI spec endpoints
apiServer.get("/v1/openapi.json", (_req, res) => {
res.json(openApiDocumentation);
});
apiServer.get("/v1/openapi.yaml", (_req, res) => {
const yamlOutput = yaml.dump(openApiDocumentation);
res.type("application/yaml").send(yamlOutput);
});
// API routes
const prefix = `/v1`;
apiServer.use(logIncomingMiddleware);
@@ -75,16 +88,6 @@ function getOpenApiDocumentation() {
}
);
for (const def of registry.definitions) {
if (def.type === "route") {
def.route.security = [
{
[bearerAuth.name]: []
}
];
}
}
registry.registerPath({
method: "get",
path: "/",
@@ -94,6 +97,74 @@ function getOpenApiDocumentation() {
responses: {}
});
registry.registerPath({
method: "get",
path: "/openapi.json",
description: "Get OpenAPI specification as JSON",
tags: [],
request: {},
responses: {
"200": {
description: "OpenAPI specification as JSON",
content: {
"application/json": {
schema: {
type: "object"
}
}
}
}
}
});
registry.registerPath({
method: "get",
path: "/openapi.yaml",
description: "Get OpenAPI specification as YAML",
tags: [],
request: {},
responses: {
"200": {
description: "OpenAPI specification as YAML",
content: {
"application/yaml": {
schema: {
type: "string"
}
}
}
}
}
});
for (const def of registry.definitions) {
if (def.type === "route") {
def.route.security = [
{
[bearerAuth.name]: []
}
];
// Ensure every route has a generic JSON response schema so Swagger UI can render responses
const existingResponses = def.route.responses;
const hasExistingResponses =
existingResponses && Object.keys(existingResponses).length > 0;
if (!hasExistingResponses) {
def.route.responses = {
"*": {
description: "",
content: {
"application/json": {
schema: z.object({})
}
}
}
};
}
}
}
const generator = new OpenApiGeneratorV3(registry.definitions);
const generated = generator.generateDocument({
@@ -105,11 +176,13 @@ function getOpenApiDocumentation() {
servers: [{ url: "/v1" }]
});
// convert to yaml and save to file
const outputPath = path.join(APP_PATH, "openapi.yaml");
const yamlOutput = yaml.dump(generated);
fs.writeFileSync(outputPath, yamlOutput, "utf8");
logger.info(`OpenAPI documentation saved to ${outputPath}`);
if (!process.env.DISABLE_GEN_OPENAPI) {
// convert to yaml and save to file
const outputPath = path.join(APP_PATH, "openapi.yaml");
const yamlOutput = yaml.dump(generated);
fs.writeFileSync(outputPath, yamlOutput, "utf8");
logger.info(`OpenAPI documentation saved to ${outputPath}`);
}
return generated;
}

View File

@@ -16,6 +16,11 @@ const internalPort = config.getRawConfig().server.internal_port;
export function createInternalServer() {
const internalServer = express();
const trustProxy = config.getRawConfig().server.trust_proxy;
if (trustProxy) {
internalServer.set("trust proxy", trustProxy);
}
internalServer.use(helmet());
internalServer.use(cors());
internalServer.use(stripDuplicateSesions);

View File

@@ -1,30 +1,44 @@
import Stripe from "stripe";
export enum FeatureId {
SITE_UPTIME = "siteUptime",
USERS = "users",
SITES = "sites",
EGRESS_DATA_MB = "egressDataMb",
DOMAINS = "domains",
REMOTE_EXIT_NODES = "remoteExitNodes"
REMOTE_EXIT_NODES = "remoteExitNodes",
ORGINIZATIONS = "organizations",
TIER1 = "tier1"
}
export const FeatureMeterIds: Record<FeatureId, string> = {
[FeatureId.SITE_UPTIME]: "mtr_61Srrej5wUJuiTWgo41D3Ee2Ir7WmDLU",
[FeatureId.USERS]: "mtr_61SrreISyIWpwUNGR41D3Ee2Ir7WmQro",
[FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW",
[FeatureId.DOMAINS]: "mtr_61Ss9nIKDNMw0LDRU41D3Ee2Ir7WmRPU",
[FeatureId.REMOTE_EXIT_NODES]: "mtr_61T86UXnfxTVXy9sD41D3Ee2Ir7WmFTE"
export async function getFeatureDisplayName(featureId: FeatureId): Promise<string> {
switch (featureId) {
case FeatureId.USERS:
return "Users";
case FeatureId.SITES:
return "Sites";
case FeatureId.EGRESS_DATA_MB:
return "Egress Data (MB)";
case FeatureId.DOMAINS:
return "Domains";
case FeatureId.REMOTE_EXIT_NODES:
return "Remote Exit Nodes";
case FeatureId.ORGINIZATIONS:
return "Organizations";
case FeatureId.TIER1:
return "Home Lab";
default:
return featureId;
}
}
// this is from the old system
export const FeatureMeterIds: Partial<Record<FeatureId, string>> = { // right now we are not charging for any data
// [FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW"
};
export const FeatureMeterIdsSandbox: Record<FeatureId, string> = {
[FeatureId.SITE_UPTIME]: "mtr_test_61Snh3cees4w60gv841DCpkOb237BDEu",
[FeatureId.USERS]: "mtr_test_61Sn5fLtq1gSfRkyA41DCpkOb237B6au",
[FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ",
[FeatureId.DOMAINS]: "mtr_test_61SsA8qrdAlgPpFRQ41DCpkOb237BGts",
[FeatureId.REMOTE_EXIT_NODES]: "mtr_test_61T86Vqmwa3D9ra3341DCpkOb237B94K"
export const FeatureMeterIdsSandbox: Partial<Record<FeatureId, string>> = {
// [FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ"
};
export function getFeatureMeterId(featureId: FeatureId): string {
export function getFeatureMeterId(featureId: FeatureId): string | undefined {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
@@ -43,45 +57,81 @@ export function getFeatureIdByMetricId(
)?.[0];
}
export type FeaturePriceSet = {
[key in Exclude<FeatureId, FeatureId.DOMAINS>]: string;
} & {
[FeatureId.DOMAINS]?: string; // Optional since domains are not billed
export type FeaturePriceSet = Partial<Record<FeatureId, string>>;
export const tier1FeaturePriceSet: FeaturePriceSet = {
[FeatureId.TIER1]: "price_1SzVE3D3Ee2Ir7Wm6wT5Dl3G"
};
export const standardFeaturePriceSet: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RrQc4D3Ee2Ir7WmaJGZ3MtF",
[FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea",
[FeatureId.EGRESS_DATA_MB]: "price_1RrQXFD3Ee2Ir7WmvGDlgxQk",
// [FeatureId.DOMAINS]: "price_1Rz3tMD3Ee2Ir7Wm5qLeASzC",
[FeatureId.REMOTE_EXIT_NODES]: "price_1S46weD3Ee2Ir7Wm94KEHI4h"
export const tier1FeaturePriceSetSandbox: FeaturePriceSet = {
[FeatureId.TIER1]: "price_1SxgpPDCpkOb237Bfo4rIsoT"
};
export const standardFeaturePriceSetSandbox: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RefFBDCpkOb237BPrKZ8IEU",
[FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF",
[FeatureId.EGRESS_DATA_MB]: "price_1Rfp9LDCpkOb237BwuN5Oiu0",
// [FeatureId.DOMAINS]: "price_1Ryi88DCpkOb237B2D6DM80b",
[FeatureId.REMOTE_EXIT_NODES]: "price_1RyiZvDCpkOb237BXpmoIYJL"
};
export function getStandardFeaturePriceSet(): FeaturePriceSet {
export function getTier1FeaturePriceSet(): FeaturePriceSet {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return standardFeaturePriceSet;
return tier1FeaturePriceSet;
} else {
return standardFeaturePriceSetSandbox;
return tier1FeaturePriceSetSandbox;
}
}
export function getLineItems(
featurePriceSet: FeaturePriceSet
): Stripe.Checkout.SessionCreateParams.LineItem[] {
return Object.entries(featurePriceSet).map(([featureId, priceId]) => ({
price: priceId
}));
export const tier2FeaturePriceSet: FeaturePriceSet = {
[FeatureId.USERS]: "price_1SzVCcD3Ee2Ir7Wmn6U3KvPN"
};
export const tier2FeaturePriceSetSandbox: FeaturePriceSet = {
[FeatureId.USERS]: "price_1SxaEHDCpkOb237BD9lBkPiR"
};
export function getTier2FeaturePriceSet(): FeaturePriceSet {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return tier2FeaturePriceSet;
} else {
return tier2FeaturePriceSetSandbox;
}
}
export const tier3FeaturePriceSet: FeaturePriceSet = {
[FeatureId.USERS]: "price_1SzVDKD3Ee2Ir7WmPtOKNusv"
};
export const tier3FeaturePriceSetSandbox: FeaturePriceSet = {
[FeatureId.USERS]: "price_1SxaEODCpkOb237BiXdCBSfs"
};
export function getTier3FeaturePriceSet(): FeaturePriceSet {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return tier3FeaturePriceSet;
} else {
return tier3FeaturePriceSetSandbox;
}
}
export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined {
// Check all feature price sets
const allPriceSets = [
getTier1FeaturePriceSet(),
getTier2FeaturePriceSet(),
getTier3FeaturePriceSet()
];
for (const priceSet of allPriceSets) {
const entry = (Object.entries(priceSet) as [FeatureId, string][]).find(
([_, price]) => price === priceId
);
if (entry) {
return entry[0];
}
}
return undefined;
}

View File

@@ -0,0 +1,25 @@
import Stripe from "stripe";
import { FeatureId, FeaturePriceSet } from "./features";
import { usageService } from "./usageService";
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -0,0 +1,37 @@
export enum LicenseId {
SMALL_LICENSE = "small_license",
BIG_LICENSE = "big_license"
}
export type LicensePriceSet = {
[key in LicenseId]: string;
};
export const licensePriceSet: LicensePriceSet = {
// Free license matches the freeLimitSet
[LicenseId.SMALL_LICENSE]: "price_1SxKHiD3Ee2Ir7WmvtEh17A8",
[LicenseId.BIG_LICENSE]: "price_1SxKHiD3Ee2Ir7WmMUiP0H6Y"
};
export const licensePriceSetSandbox: LicensePriceSet = {
// Free license matches the freeLimitSet
// when matching license the keys closer to 0 index are matched first so list the licenses in descending order of value
[LicenseId.SMALL_LICENSE]: "price_1SxDwuDCpkOb237Bz0yTiOgN",
[LicenseId.BIG_LICENSE]: "price_1SxDy0DCpkOb237BWJxrxYkl"
};
export function getLicensePriceSet(
environment?: string,
sandbox_mode?: boolean
): LicensePriceSet {
if (
(process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true") ||
(environment === "prod" && sandbox_mode !== true)
) {
// THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
return licensePriceSet;
} else {
return licensePriceSetSandbox;
}
}

View File

@@ -1,50 +1,70 @@
import { FeatureId } from "./features";
export type LimitSet = {
export type LimitSet = Partial<{
[key in FeatureId]: {
value: number | null; // null indicates no limit
description?: string;
};
};
export const sandboxLimitSet: LimitSet = {
[FeatureId.SITE_UPTIME]: { value: 2880, description: "Sandbox limit" }, // 1 site up for 2 days
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.EGRESS_DATA_MB]: { value: 1000, description: "Sandbox limit" }, // 1 GB
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }
};
}>;
export const freeLimitSet: LimitSet = {
[FeatureId.SITE_UPTIME]: { value: 46080, description: "Free tier limit" }, // 1 site up for 32 days
[FeatureId.USERS]: { value: 3, description: "Free tier limit" },
[FeatureId.EGRESS_DATA_MB]: {
value: 25000,
description: "Free tier limit"
}, // 25 GB
[FeatureId.DOMAINS]: { value: 3, description: "Free tier limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Free tier limit" }
[FeatureId.SITES]: { value: 5, description: "Basic limit" },
[FeatureId.USERS]: { value: 5, description: "Basic limit" },
[FeatureId.DOMAINS]: { value: 5, description: "Basic limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Basic limit" },
};
export const subscribedLimitSet: LimitSet = {
[FeatureId.SITE_UPTIME]: {
value: 2232000,
description: "Contact us to increase soft limit."
}, // 50 sites up for 31 days
export const tier1LimitSet: LimitSet = {
[FeatureId.USERS]: { value: 7, description: "Home limit" },
[FeatureId.SITES]: { value: 10, description: "Home limit" },
[FeatureId.DOMAINS]: { value: 10, description: "Home limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Home limit" },
};
export const tier2LimitSet: LimitSet = {
[FeatureId.USERS]: {
value: 150,
description: "Contact us to increase soft limit."
value: 100,
description: "Team limit"
},
[FeatureId.SITES]: {
value: 50,
description: "Team limit"
},
[FeatureId.EGRESS_DATA_MB]: {
value: 12000000,
description: "Contact us to increase soft limit."
}, // 12000 GB
[FeatureId.DOMAINS]: {
value: 25,
description: "Contact us to increase soft limit."
value: 50,
description: "Team limit"
},
[FeatureId.REMOTE_EXIT_NODES]: {
value: 5,
description: "Contact us to increase soft limit."
value: 3,
description: "Team limit"
},
[FeatureId.ORGINIZATIONS]: {
value: 1,
description: "Team limit"
}
};
export const tier3LimitSet: LimitSet = {
[FeatureId.USERS]: {
value: 500,
description: "Business limit"
},
[FeatureId.SITES]: {
value: 250,
description: "Business limit"
},
[FeatureId.DOMAINS]: {
value: 100,
description: "Business limit"
},
[FeatureId.REMOTE_EXIT_NODES]: {
value: 20,
description: "Business limit"
},
[FeatureId.ORGINIZATIONS]: {
value: 5,
description: "Business limit"
},
};

View File

@@ -2,6 +2,7 @@ import { db, limits } from "@server/db";
import { and, eq } from "drizzle-orm";
import { LimitSet } from "./limitSet";
import { FeatureId } from "./features";
import logger from "@server/logger";
class LimitService {
async applyLimitSetToOrg(orgId: string, limitSet: LimitSet): Promise<void> {
@@ -13,6 +14,21 @@ class LimitService {
for (const [featureId, entry] of limitEntries) {
const limitId = `${orgId}-${featureId}`;
const { value, description } = entry;
// get the limit first
const [limit] = await trx
.select()
.from(limits)
.where(eq(limits.limitId, limitId))
.limit(1);
// check if its overriden
if (limit && limit.override) {
logger.debug(
`Skipping limit ${limitId} for org ${orgId} since it is overridden...`
);
continue;
}
await trx
.insert(limits)
.values({ limitId, orgId, featureId, value, description });

View File

@@ -0,0 +1,58 @@
import { Tier } from "@server/types/Tiers";
export enum TierFeature {
OrgOidc = "orgOidc",
LoginPageDomain = "loginPageDomain", // handle downgrade by removing custom domain
DeviceApprovals = "deviceApprovals", // handle downgrade by disabling device approvals
LoginPageBranding = "loginPageBranding", // handle downgrade by setting to default branding
LogExport = "logExport",
AccessLogs = "accessLogs", // set the retention period to none on downgrade
ActionLogs = "actionLogs", // set the retention period to none on downgrade
ConnectionLogs = "connectionLogs",
RotateCredentials = "rotateCredentials",
MaintencePage = "maintencePage", // handle downgrade
DevicePosture = "devicePosture",
TwoFactorEnforcement = "twoFactorEnforcement", // handle downgrade by setting to optional
SessionDurationPolicies = "sessionDurationPolicies", // handle downgrade by setting to default duration
PasswordExpirationPolicies = "passwordExpirationPolicies", // handle downgrade by setting to default duration
AutoProvisioning = "autoProvisioning", // handle downgrade by disabling auto provisioning
SshPam = "sshPam",
FullRbac = "fullRbac",
SiteProvisioningKeys = "siteProvisioningKeys" // handle downgrade by revoking keys if needed
}
export const tierMatrix: Record<TierFeature, Tier[]> = {
[TierFeature.OrgOidc]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.LoginPageDomain]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.DeviceApprovals]: ["tier1", "tier3", "enterprise"],
[TierFeature.LoginPageBranding]: ["tier1", "tier3", "enterprise"],
[TierFeature.LogExport]: ["tier3", "enterprise"],
[TierFeature.AccessLogs]: ["tier2", "tier3", "enterprise"],
[TierFeature.ActionLogs]: ["tier2", "tier3", "enterprise"],
[TierFeature.ConnectionLogs]: ["tier2", "tier3", "enterprise"],
[TierFeature.RotateCredentials]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.MaintencePage]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.DevicePosture]: ["tier2", "tier3", "enterprise"],
[TierFeature.TwoFactorEnforcement]: [
"tier1",
"tier2",
"tier3",
"enterprise"
],
[TierFeature.SessionDurationPolicies]: [
"tier1",
"tier2",
"tier3",
"enterprise"
],
[TierFeature.PasswordExpirationPolicies]: [
"tier1",
"tier2",
"tier3",
"enterprise"
],
[TierFeature.AutoProvisioning]: ["tier1", "tier3", "enterprise"],
[TierFeature.SshPam]: ["tier1", "tier3", "enterprise"],
[TierFeature.FullRbac]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.SiteProvisioningKeys]: ["enterprise"]
};

View File

@@ -1,34 +0,0 @@
export enum TierId {
STANDARD = "standard"
}
export type TierPriceSet = {
[key in TierId]: string;
};
export const tierPriceSet: TierPriceSet = {
// Free tier matches the freeLimitSet
[TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0"
};
export const tierPriceSetSandbox: TierPriceSet = {
// Free tier matches the freeLimitSet
// when matching tier the keys closer to 0 index are matched first so list the tiers in descending order of value
[TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m"
};
export function getTierPriceSet(
environment?: string,
sandbox_mode?: boolean
): TierPriceSet {
if (
(process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true") ||
(environment === "prod" && sandbox_mode !== true)
) {
// THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
return tierPriceSet;
} else {
return tierPriceSetSandbox;
}
}

View File

@@ -1,74 +1,32 @@
import { eq, sql, and } from "drizzle-orm";
import { v4 as uuidv4 } from "uuid";
import { PutObjectCommand } from "@aws-sdk/client-s3";
import * as fs from "fs/promises";
import * as path from "path";
import {
db,
usage,
customers,
sites,
newts,
limits,
Usage,
Limit,
Transaction
Transaction,
orgs
} from "@server/db";
import { FeatureId, getFeatureMeterId } from "./features";
import logger from "@server/logger";
import { sendToClient } from "#dynamic/routers/ws";
import { build } from "@server/build";
import { s3Client } from "@server/lib/s3";
import cache from "@server/lib/cache";
interface StripeEvent {
identifier?: string;
timestamp: number;
event_name: string;
payload: {
value: number;
stripe_customer_id: string;
};
}
import cache from "#dynamic/lib/cache";
export function noop() {
if (
build !== "saas" ||
!process.env.S3_BUCKET ||
!process.env.LOCAL_FILE_PATH
) {
if (build !== "saas") {
return true;
}
return false;
}
export class UsageService {
private bucketName: string | undefined;
private currentEventFile: string | null = null;
private currentFileStartTime: number = 0;
private eventsDir: string | undefined;
private uploadingFiles: Set<string> = new Set();
constructor() {
if (noop()) {
return;
}
// this.bucketName = privateConfig.getRawPrivateConfig().stripe?.s3Bucket;
// this.eventsDir = privateConfig.getRawPrivateConfig().stripe?.localFilePath;
this.bucketName = process.env.S3_BUCKET || undefined;
this.eventsDir = process.env.LOCAL_FILE_PATH || undefined;
// Ensure events directory exists
this.initializeEventsDirectory().then(() => {
this.uploadPendingEventFilesOnStartup();
});
// Periodically check for old event files to upload
setInterval(() => {
this.uploadOldEventFiles().catch((err) => {
logger.error("Error in periodic event file upload:", err);
});
}, 30000); // every 30 seconds
}
/**
@@ -78,85 +36,6 @@ export class UsageService {
return Math.round(value * 100000000000) / 100000000000; // 11 decimal places
}
private async initializeEventsDirectory(): Promise<void> {
if (!this.eventsDir) {
logger.warn(
"Stripe local file path is not configured, skipping events directory initialization."
);
return;
}
try {
await fs.mkdir(this.eventsDir, { recursive: true });
} catch (error) {
logger.error("Failed to create events directory:", error);
}
}
private async uploadPendingEventFilesOnStartup(): Promise<void> {
if (!this.eventsDir || !this.bucketName) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping leftover event file upload."
);
return;
}
try {
const files = await fs.readdir(this.eventsDir);
for (const file of files) {
if (file.endsWith(".json")) {
const filePath = path.join(this.eventsDir, file);
try {
const fileContent = await fs.readFile(
filePath,
"utf-8"
);
const events = JSON.parse(fileContent);
if (Array.isArray(events) && events.length > 0) {
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: file,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
// Check if file still exists before unlinking
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`Startup file ${file} was already deleted`
);
}
logger.info(
`Uploaded leftover event file ${file} to S3 with ${events.length} events`
);
} else {
// Remove empty file
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`Empty startup file ${file} was already deleted`
);
}
}
} catch (err) {
logger.error(
`Error processing leftover event file ${file}:`,
err
);
}
}
}
} catch (error) {
logger.error("Failed to scan for leftover event files");
}
}
public async add(
orgId: string,
featureId: FeatureId,
@@ -176,28 +55,20 @@ export class UsageService {
while (attempt <= maxRetries) {
try {
// Get subscription data for this org (with caching)
const customerId = await this.getCustomerId(orgId, featureId);
if (!customerId) {
logger.warn(
`No subscription data found for org ${orgId} and feature ${featureId}`
);
return null;
}
let usage;
if (transaction) {
const orgIdToUse = await this.getBillingOrg(orgId, transaction);
usage = await this.internalAddUsage(
orgId,
orgIdToUse,
featureId,
value,
transaction
);
} else {
await db.transaction(async (trx) => {
const orgIdToUse = await this.getBillingOrg(orgId, trx);
usage = await this.internalAddUsage(
orgId,
orgIdToUse,
featureId,
value,
trx
@@ -205,9 +76,6 @@ export class UsageService {
});
}
// Log event for Stripe
await this.logStripeEvent(featureId, value, customerId);
return usage || null;
} catch (error: any) {
// Check if this is a deadlock error
@@ -243,7 +111,7 @@ export class UsageService {
}
private async internalAddUsage(
orgId: string,
orgId: string, // here the orgId is the billing org already resolved by getBillingOrg in updateCount
featureId: FeatureId,
value: number,
trx: Transaction
@@ -262,17 +130,22 @@ export class UsageService {
featureId,
orgId,
meterId,
latestValue: value,
instantaneousValue: value || 0,
latestValue: value || 0,
updatedAt: Math.floor(Date.now() / 1000)
})
.onConflictDoUpdate({
target: usage.usageId,
set: {
latestValue: sql`${usage.latestValue} + ${value}`
instantaneousValue: sql`COALESCE(${usage.instantaneousValue}, 0) + ${value}`
}
})
.returning();
logger.debug(
`Added usage for org ${orgId} feature ${featureId}: +${value}, new instantaneousValue: ${returnUsage.instantaneousValue}`
);
return returnUsage;
}
@@ -286,7 +159,7 @@ export class UsageService {
return new Date(date * 1000).toISOString().split("T")[0];
}
async updateDaily(
async updateCount(
orgId: string,
featureId: FeatureId,
value?: number,
@@ -295,30 +168,20 @@ export class UsageService {
if (noop()) {
return;
}
try {
if (!customerId) {
customerId =
(await this.getCustomerId(orgId, featureId)) || undefined;
if (!customerId) {
logger.warn(
`No subscription data found for org ${orgId} and feature ${featureId}`
);
return;
}
}
const orgIdToUse = await this.getBillingOrg(orgId);
try {
// Truncate value to 11 decimal places if provided
if (value !== undefined && value !== null) {
value = this.truncateValue(value);
}
const today = this.getTodayDateString();
let currentUsage: Usage | null = null;
await db.transaction(async (trx) => {
// Get existing meter record
const usageId = `${orgId}-${featureId}`;
const usageId = `${orgIdToUse}-${featureId}`;
// Get current usage record
[currentUsage] = await trx
.select()
@@ -327,66 +190,34 @@ export class UsageService {
.limit(1);
if (currentUsage) {
const lastUpdateDate = this.getDateString(
currentUsage.updatedAt
);
const currentRunningTotal = currentUsage.latestValue;
const lastDailyValue = currentUsage.instantaneousValue || 0;
if (value == undefined || value === null) {
value = currentUsage.instantaneousValue || 0;
}
if (lastUpdateDate === today) {
// Same day update: replace the daily value
// Remove old daily value from running total, add new value
const newRunningTotal = this.truncateValue(
currentRunningTotal - lastDailyValue + value
);
await trx
.update(usage)
.set({
latestValue: newRunningTotal,
instantaneousValue: value,
updatedAt: Math.floor(Date.now() / 1000)
})
.where(eq(usage.usageId, usageId));
} else {
// New day: add to running total
const newRunningTotal = this.truncateValue(
currentRunningTotal + value
);
await trx
.update(usage)
.set({
latestValue: newRunningTotal,
instantaneousValue: value,
updatedAt: Math.floor(Date.now() / 1000)
})
.where(eq(usage.usageId, usageId));
}
await trx
.update(usage)
.set({
instantaneousValue: value,
updatedAt: Math.floor(Date.now() / 1000)
})
.where(eq(usage.usageId, usageId));
} else {
// First record for this meter
const meterId = getFeatureMeterId(featureId);
const truncatedValue = this.truncateValue(value || 0);
await trx.insert(usage).values({
usageId,
featureId,
orgId,
orgId: orgIdToUse,
meterId,
instantaneousValue: truncatedValue,
latestValue: truncatedValue,
instantaneousValue: value || 0,
latestValue: value || 0,
updatedAt: Math.floor(Date.now() / 1000)
});
}
});
await this.logStripeEvent(featureId, value || 0, customerId);
// if (privateConfig.getRawPrivateConfig().flags.usage_reporting) {
// await this.logStripeEvent(featureId, value || 0, customerId);
// }
} catch (error) {
logger.error(
`Failed to update daily usage for ${orgId}/${featureId}:`,
`Failed to update count usage for ${orgIdToUse}/${featureId}:`,
error
);
}
@@ -396,8 +227,10 @@ export class UsageService {
orgId: string,
featureId: FeatureId
): Promise<string | null> {
const cacheKey = `customer_${orgId}_${featureId}`;
const cached = cache.get<string>(cacheKey);
const orgIdToUse = await this.getBillingOrg(orgId);
const cacheKey = `customer_${orgIdToUse}_${featureId}`;
const cached = await cache.get<string>(cacheKey);
if (cached) {
return cached;
@@ -410,7 +243,7 @@ export class UsageService {
customerId: customers.customerId
})
.from(customers)
.where(eq(customers.orgId, orgId))
.where(eq(customers.orgId, orgIdToUse))
.limit(1);
if (!customer) {
@@ -420,194 +253,18 @@ export class UsageService {
const customerId = customer.customerId;
// Cache the result
cache.set(cacheKey, customerId, 300); // 5 minute TTL
await cache.set(cacheKey, customerId, 300); // 5 minute TTL
return customerId;
} catch (error) {
logger.error(
`Failed to get subscription data for ${orgId}/${featureId}:`,
`Failed to get subscription data for ${orgIdToUse}/${featureId}:`,
error
);
return null;
}
}
private async logStripeEvent(
featureId: FeatureId,
value: number,
customerId: string
): Promise<void> {
// Truncate value to 11 decimal places before sending to Stripe
const truncatedValue = this.truncateValue(value);
const event: StripeEvent = {
identifier: uuidv4(),
timestamp: Math.floor(new Date().getTime() / 1000),
event_name: featureId,
payload: {
value: truncatedValue,
stripe_customer_id: customerId
}
};
await this.writeEventToFile(event);
await this.checkAndUploadFile();
}
private async writeEventToFile(event: StripeEvent): Promise<void> {
if (!this.eventsDir || !this.bucketName) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping event file write."
);
return;
}
if (!this.currentEventFile) {
this.currentEventFile = this.generateEventFileName();
this.currentFileStartTime = Date.now();
}
const filePath = path.join(this.eventsDir, this.currentEventFile);
try {
let events: StripeEvent[] = [];
// Try to read existing file
try {
const fileContent = await fs.readFile(filePath, "utf-8");
events = JSON.parse(fileContent);
} catch (error) {
// File doesn't exist or is empty, start with empty array
events = [];
}
// Add new event
events.push(event);
// Write back to file
await fs.writeFile(filePath, JSON.stringify(events, null, 2));
} catch (error) {
logger.error("Failed to write event to file:", error);
}
}
private async checkAndUploadFile(): Promise<void> {
if (!this.currentEventFile) {
return;
}
const now = Date.now();
const fileAge = now - this.currentFileStartTime;
// Check if file is at least 1 minute old
if (fileAge >= 60000) {
// 60 seconds
await this.uploadFileToS3();
}
}
private async uploadFileToS3(): Promise<void> {
if (!this.bucketName || !this.eventsDir) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping S3 upload."
);
return;
}
if (!this.currentEventFile) {
return;
}
const fileName = this.currentEventFile;
const filePath = path.join(this.eventsDir, fileName);
// Check if this file is already being uploaded
if (this.uploadingFiles.has(fileName)) {
logger.debug(
`File ${fileName} is already being uploaded, skipping`
);
return;
}
// Mark file as being uploaded
this.uploadingFiles.add(fileName);
try {
// Check if file exists before trying to read it
try {
await fs.access(filePath);
} catch (error) {
logger.debug(
`File ${fileName} does not exist, may have been already processed`
);
this.uploadingFiles.delete(fileName);
// Reset current file if it was this file
if (this.currentEventFile === fileName) {
this.currentEventFile = null;
this.currentFileStartTime = 0;
}
return;
}
// Check if file exists and has content
const fileContent = await fs.readFile(filePath, "utf-8");
const events = JSON.parse(fileContent);
if (events.length === 0) {
// No events to upload, just clean up
try {
await fs.unlink(filePath);
} catch (unlinkError) {
// File may have been already deleted
logger.debug(
`File ${fileName} was already deleted during cleanup`
);
}
this.currentEventFile = null;
this.uploadingFiles.delete(fileName);
return;
}
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: fileName,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
// Clean up local file - check if it still exists before unlinking
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
// File may have been already deleted by another process
logger.debug(
`File ${fileName} was already deleted during upload`
);
}
logger.info(
`Uploaded ${fileName} to S3 with ${events.length} events`
);
// Reset for next file
this.currentEventFile = null;
this.currentFileStartTime = 0;
} catch (error) {
logger.error(`Failed to upload ${fileName} to S3:`, error);
} finally {
// Always remove from uploading set
this.uploadingFiles.delete(fileName);
}
}
private generateEventFileName(): string {
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
const uuid = uuidv4().substring(0, 8);
return `events-${timestamp}-${uuid}.json`;
}
public async getUsage(
orgId: string,
featureId: FeatureId,
@@ -617,7 +274,9 @@ export class UsageService {
return null;
}
const usageId = `${orgId}-${featureId}`;
const orgIdToUse = await this.getBillingOrg(orgId, trx);
const usageId = `${orgIdToUse}-${featureId}`;
try {
const [result] = await trx
@@ -629,7 +288,7 @@ export class UsageService {
if (!result) {
// Lets create one if it doesn't exist using upsert to handle race conditions
logger.info(
`Creating new usage record for ${orgId}/${featureId}`
`Creating new usage record for ${orgIdToUse}/${featureId}`
);
const meterId = getFeatureMeterId(featureId);
@@ -639,7 +298,7 @@ export class UsageService {
.values({
usageId,
featureId,
orgId,
orgId: orgIdToUse,
meterId,
latestValue: 0,
updatedAt: Math.floor(Date.now() / 1000)
@@ -661,7 +320,7 @@ export class UsageService {
} catch (insertError) {
// Fallback: try to fetch existing record in case of any insert issues
logger.warn(
`Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`,
`Insert failed for ${orgIdToUse}/${featureId}, attempting to fetch existing record:`,
insertError
);
const [existingUsage] = await trx
@@ -676,136 +335,45 @@ export class UsageService {
return result;
} catch (error) {
logger.error(
`Failed to get usage for ${orgId}/${featureId}:`,
`Failed to get usage for ${orgIdToUse}/${featureId}:`,
error
);
throw error;
}
}
public async getUsageDaily(
public async getBillingOrg(
orgId: string,
featureId: FeatureId
): Promise<Usage | null> {
if (noop()) {
return null;
trx: Transaction | typeof db = db
): Promise<string> {
let orgIdToUse = orgId;
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error(`Organization with ID ${orgId} not found`);
}
await this.updateDaily(orgId, featureId); // Ensure daily usage is updated
return this.getUsage(orgId, featureId);
}
public async forceUpload(): Promise<void> {
await this.uploadFileToS3();
}
/**
* Scan the events directory for files older than 1 minute and upload them if not empty.
*/
private async uploadOldEventFiles(): Promise<void> {
if (!this.eventsDir || !this.bucketName) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping old event file upload."
);
return;
}
try {
const files = await fs.readdir(this.eventsDir);
const now = Date.now();
for (const file of files) {
if (!file.endsWith(".json")) continue;
// Skip files that are already being uploaded
if (this.uploadingFiles.has(file)) {
logger.debug(
`Skipping file ${file} as it's already being uploaded`
);
continue;
}
const filePath = path.join(this.eventsDir, file);
try {
// Check if file still exists before processing
try {
await fs.access(filePath);
} catch (accessError) {
logger.debug(`File ${file} does not exist, skipping`);
continue;
}
const stat = await fs.stat(filePath);
const age = now - stat.mtimeMs;
if (age >= 90000) {
// 1.5 minutes - Mark as being uploaded
this.uploadingFiles.add(file);
try {
const fileContent = await fs.readFile(
filePath,
"utf-8"
);
const events = JSON.parse(fileContent);
if (Array.isArray(events) && events.length > 0) {
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: file,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
// Check if file still exists before unlinking
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`File ${file} was already deleted during interval upload`
);
}
logger.info(
`Interval: Uploaded event file ${file} to S3 with ${events.length} events`
);
// If this was the current event file, reset it
if (this.currentEventFile === file) {
this.currentEventFile = null;
this.currentFileStartTime = 0;
}
} else {
// Remove empty file
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`Empty file ${file} was already deleted`
);
}
}
} finally {
// Always remove from uploading set
this.uploadingFiles.delete(file);
}
}
} catch (err) {
logger.error(
`Interval: Error processing event file ${file}:`,
err
);
// Remove from uploading set on error
this.uploadingFiles.delete(file);
}
if (!org.isBillingOrg) {
if (org.billingOrgId) {
orgIdToUse = org.billingOrgId;
} else {
throw new Error(
`Organization ${orgId} is not a billing org and does not have a billingOrgId set`
);
}
} catch (err) {
logger.error("Interval: Failed to scan for event files:", err);
}
return orgIdToUse;
}
public async checkLimitSet(
orgId: string,
kickSites = false,
featureId?: FeatureId,
usage?: Usage,
trx: Transaction | typeof db = db
@@ -813,6 +381,9 @@ export class UsageService {
if (noop()) {
return false;
}
const orgIdToUse = await this.getBillingOrg(orgId, trx);
// This method should check the current usage against the limits set for the organization
// and kick out all of the sites on the org
let hasExceededLimits = false;
@@ -826,7 +397,7 @@ export class UsageService {
.from(limits)
.where(
and(
eq(limits.orgId, orgId),
eq(limits.orgId, orgIdToUse),
eq(limits.featureId, featureId)
)
);
@@ -835,11 +406,11 @@ export class UsageService {
orgLimits = await trx
.select()
.from(limits)
.where(eq(limits.orgId, orgId));
.where(eq(limits.orgId, orgIdToUse));
}
if (orgLimits.length === 0) {
logger.debug(`No limits set for org ${orgId}`);
logger.debug(`No limits set for org ${orgIdToUse}`);
return false;
}
@@ -850,7 +421,7 @@ export class UsageService {
currentUsage = usage;
} else {
currentUsage = await this.getUsage(
orgId,
orgIdToUse,
limit.featureId as FeatureId,
trx
);
@@ -861,10 +432,10 @@ export class UsageService {
currentUsage?.latestValue ||
0;
logger.debug(
`Current usage for org ${orgId} on feature ${limit.featureId}: ${usageValue}`
`Current usage for org ${orgIdToUse} on feature ${limit.featureId}: ${usageValue}`
);
logger.debug(
`Limit for org ${orgId} on feature ${limit.featureId}: ${limit.value}`
`Limit for org ${orgIdToUse} on feature ${limit.featureId}: ${limit.value}`
);
if (
currentUsage &&
@@ -872,67 +443,15 @@ export class UsageService {
usageValue > limit.value
) {
logger.debug(
`Org ${orgId} has exceeded limit for ${limit.featureId}: ` +
`Org ${orgIdToUse} has exceeded limit for ${limit.featureId}: ` +
`${usageValue} > ${limit.value}`
);
hasExceededLimits = true;
break; // Exit early if any limit is exceeded
}
}
// If any limits are exceeded, disconnect all sites for this organization
if (hasExceededLimits && kickSites) {
logger.warn(
`Disconnecting all sites for org ${orgId} due to exceeded limits`
);
// Get all sites for this organization
const orgSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, orgId));
// Mark all sites as offline and send termination messages
const siteUpdates = orgSites.map((site) => site.siteId);
if (siteUpdates.length > 0) {
// Send termination messages to newt sites
for (const site of orgSites) {
if (site.type === "newt") {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (newt) {
const payload = {
type: `newt/wg/terminate`,
data: {
reason: "Usage limits exceeded"
}
};
// Don't await to prevent blocking
await sendToClient(newt.newtId, payload).catch(
(error: any) => {
logger.error(
`Failed to send termination message to newt ${newt.newtId}:`,
error
);
}
);
}
}
}
logger.info(
`Disconnected ${orgSites.length} sites for org ${orgId} due to exceeded limits`
);
}
}
} catch (error) {
logger.error(`Error checking limits for org ${orgId}:`, error);
logger.error(`Error checking limits for org ${orgIdToUse}:`, error);
}
return hasExceededLimits;

View File

@@ -0,0 +1,3 @@
import { z } from "zod";
export const MaintenanceSchema = z.object({});

View File

@@ -1,4 +1,14 @@
import { db, newts, blueprints, Blueprint } from "@server/db";
import {
db,
newts,
blueprints,
Blueprint,
Site,
siteResources,
roleSiteResources,
userSiteResources,
clientSiteResources
} from "@server/db";
import { Config, ConfigSchema } from "./types";
import { ProxyResourcesResults, updateProxyResources } from "./proxyResources";
import { fromError } from "zod-validation-error";
@@ -15,6 +25,7 @@ import { BlueprintSource } from "@server/routers/blueprints/types";
import { stringify as stringifyYaml } from "yaml";
import { faker } from "@faker-js/faker";
import { handleMessagingForUpdatedSiteResource } from "@server/routers/siteResource";
import { rebuildClientAssociationsFromSiteResource } from "../rebuildClientAssociations";
type ApplyBlueprintArgs = {
orgId: string;
@@ -96,7 +107,7 @@ export async function applyBlueprint({
[target],
matchingHealthcheck ? [matchingHealthcheck] : [],
result.proxyResource.protocol,
result.proxyResource.proxyPort
site.newt.version
);
}
}
@@ -108,38 +119,145 @@ export async function applyBlueprint({
// We need to update the targets on the newts from the successfully updated information
for (const result of clientResourcesResults) {
const [site] = await trx
.select()
.from(sites)
.innerJoin(newts, eq(sites.siteId, newts.siteId))
.where(
and(
eq(sites.siteId, result.newSiteResource.siteId),
eq(sites.orgId, orgId),
eq(sites.type, "newt"),
isNotNull(sites.pubKey)
if (
result.oldSiteResource &&
result.oldSiteResource.siteId !=
result.newSiteResource.siteId
) {
// query existing associations
const existingRoleIds = await trx
.select()
.from(roleSiteResources)
.where(
eq(
roleSiteResources.siteResourceId,
result.oldSiteResource.siteResourceId
)
)
)
.limit(1);
.then((rows) => rows.map((row) => row.roleId));
if (!site) {
logger.debug(
`No newt site found for client resource ${result.newSiteResource.siteResourceId}, skipping target update`
const existingUserIds = await trx
.select()
.from(userSiteResources)
.where(
eq(
userSiteResources.siteResourceId,
result.oldSiteResource.siteResourceId
)
)
.then((rows) => rows.map((row) => row.userId));
const existingClientIds = await trx
.select()
.from(clientSiteResources)
.where(
eq(
clientSiteResources.siteResourceId,
result.oldSiteResource.siteResourceId
)
)
.then((rows) => rows.map((row) => row.clientId));
// delete the existing site resource
await trx
.delete(siteResources)
.where(
and(
eq(
siteResources.siteResourceId,
result.oldSiteResource.siteResourceId
)
)
);
await rebuildClientAssociationsFromSiteResource(
result.oldSiteResource,
trx
);
const [insertedSiteResource] = await trx
.insert(siteResources)
.values({
...result.newSiteResource
})
.returning();
// wait some time to allow for messages to be handled
await new Promise((resolve) => setTimeout(resolve, 750));
//////////////////// update the associations ////////////////////
if (existingRoleIds.length > 0) {
await trx.insert(roleSiteResources).values(
existingRoleIds.map((roleId) => ({
roleId,
siteResourceId:
insertedSiteResource!.siteResourceId
}))
);
}
if (existingUserIds.length > 0) {
await trx.insert(userSiteResources).values(
existingUserIds.map((userId) => ({
userId,
siteResourceId:
insertedSiteResource!.siteResourceId
}))
);
}
if (existingClientIds.length > 0) {
await trx.insert(clientSiteResources).values(
existingClientIds.map((clientId) => ({
clientId,
siteResourceId:
insertedSiteResource!.siteResourceId
}))
);
}
await rebuildClientAssociationsFromSiteResource(
insertedSiteResource,
trx
);
} else {
const [newSite] = await trx
.select()
.from(sites)
.innerJoin(newts, eq(sites.siteId, newts.siteId))
.where(
and(
eq(sites.siteId, result.newSiteResource.siteId),
eq(sites.orgId, orgId),
eq(sites.type, "newt"),
isNotNull(sites.pubKey)
)
)
.limit(1);
if (!newSite) {
logger.debug(
`No newt site found for client resource ${result.newSiteResource.siteResourceId}, skipping target update`
);
continue;
}
logger.debug(
`Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.sites.siteId}`
);
await handleMessagingForUpdatedSiteResource(
result.oldSiteResource,
result.newSiteResource,
{
siteId: newSite.sites.siteId,
orgId: newSite.sites.orgId
},
trx
);
continue;
}
logger.debug(
`Updating client resource ${result.newSiteResource.siteResourceId} on site ${site.sites.siteId}`
);
await handleMessagingForUpdatedSiteResource(
result.oldSiteResource,
result.newSiteResource,
{ siteId: site.sites.siteId, orgId: site.sites.orgId },
trx
);
// await addClientTargets(
// site.newt.newtId,
// result.resource.destination,

View File

@@ -36,7 +36,9 @@ export async function applyNewtDockerBlueprint(
if (
isEmptyObject(blueprint["proxy-resources"]) &&
isEmptyObject(blueprint["client-resources"])
isEmptyObject(blueprint["client-resources"]) &&
isEmptyObject(blueprint["public-resources"]) &&
isEmptyObject(blueprint["private-resources"])
) {
return;
}

View File

@@ -11,9 +11,10 @@ import {
userSiteResources
} from "@server/db";
import { sites } from "@server/db";
import { eq, and, ne, inArray } from "drizzle-orm";
import { eq, and, ne, inArray, or } from "drizzle-orm";
import { Config } from "./types";
import logger from "@server/logger";
import { getNextAvailableAliasAddress } from "../ip";
export type ClientResourcesResults = {
newSiteResource: SiteResource;
@@ -75,22 +76,20 @@ export async function updateClientResources(
}
if (existingResource) {
if (existingResource.siteId !== site.siteId) {
throw new Error(
`You can not change the site of an existing client resource (${resourceNiceId}). Please delete and recreate it instead.`
);
}
// Update existing resource
const [updatedResource] = await trx
.update(siteResources)
.set({
name: resourceData.name || resourceNiceId,
siteId: site.siteId,
mode: resourceData.mode,
destination: resourceData.destination,
enabled: true, // hardcoded for now
// enabled: resourceData.enabled ?? true,
alias: resourceData.alias || null
alias: resourceData.alias || null,
disableIcmp: resourceData["disable-icmp"],
tcpPortRangeString: resourceData["tcp-ports"],
udpPortRangeString: resourceData["udp-ports"]
})
.where(
eq(
@@ -143,7 +142,10 @@ export async function updateClientResources(
.innerJoin(userOrgs, eq(users.userId, userOrgs.userId))
.where(
and(
inArray(users.username, resourceData.users),
or(
inArray(users.username, resourceData.users),
inArray(users.email, resourceData.users)
),
eq(userOrgs.orgId, orgId)
)
);
@@ -205,6 +207,12 @@ export async function updateClientResources(
oldSiteResource: existingResource
});
} else {
let aliasAddress: string | null = null;
if (resourceData.mode == "host") {
// we can only have an alias on a host
aliasAddress = await getNextAvailableAliasAddress(orgId);
}
// Create new resource
const [newResource] = await trx
.insert(siteResources)
@@ -217,7 +225,11 @@ export async function updateClientResources(
destination: resourceData.destination,
enabled: true, // hardcoded for now
// enabled: resourceData.enabled ?? true,
alias: resourceData.alias || null
alias: resourceData.alias || null,
aliasAddress: aliasAddress,
disableIcmp: resourceData["disable-icmp"],
tcpPortRangeString: resourceData["tcp-ports"],
udpPortRangeString: resourceData["udp-ports"]
})
.returning();
@@ -267,7 +279,10 @@ export async function updateClientResources(
.innerJoin(userOrgs, eq(users.userId, userOrgs.userId))
.where(
and(
inArray(users.username, resourceData.users),
or(
inArray(users.username, resourceData.users),
inArray(users.email, resourceData.users)
),
eq(userOrgs.orgId, orgId)
)
);

View File

@@ -54,10 +54,14 @@ function getContainerPort(container: Container): number | null {
export function processContainerLabels(containers: Container[]): {
"proxy-resources": { [key: string]: ResourceConfig };
"client-resources": { [key: string]: ResourceConfig };
"public-resources": { [key: string]: ResourceConfig };
"private-resources": { [key: string]: ResourceConfig };
} {
const result = {
"proxy-resources": {} as { [key: string]: ResourceConfig },
"client-resources": {} as { [key: string]: ResourceConfig }
"client-resources": {} as { [key: string]: ResourceConfig },
"public-resources": {} as { [key: string]: ResourceConfig },
"private-resources": {} as { [key: string]: ResourceConfig }
};
// Process each container
@@ -68,8 +72,10 @@ export function processContainerLabels(containers: Container[]): {
const proxyResourceLabels: DockerLabels = {};
const clientResourceLabels: DockerLabels = {};
const publicResourceLabels: DockerLabels = {};
const privateResourceLabels: DockerLabels = {};
// Filter and separate proxy-resources and client-resources labels
// Filter and separate proxy-resources, client-resources, public-resources, and private-resources labels
Object.entries(container.labels).forEach(([key, value]) => {
if (key.startsWith("pangolin.proxy-resources.")) {
// remove the pangolin.proxy- prefix to get "resources.xxx"
@@ -79,6 +85,14 @@ export function processContainerLabels(containers: Container[]): {
// remove the pangolin.client- prefix to get "resources.xxx"
const strippedKey = key.replace("pangolin.client-", "");
clientResourceLabels[strippedKey] = value;
} else if (key.startsWith("pangolin.public-resources.")) {
// remove the pangolin.public- prefix to get "resources.xxx"
const strippedKey = key.replace("pangolin.public-", "");
publicResourceLabels[strippedKey] = value;
} else if (key.startsWith("pangolin.private-resources.")) {
// remove the pangolin.private- prefix to get "resources.xxx"
const strippedKey = key.replace("pangolin.private-", "");
privateResourceLabels[strippedKey] = value;
}
});
@@ -99,6 +113,24 @@ export function processContainerLabels(containers: Container[]): {
result["client-resources"]
);
}
// Process public resources (alias for proxy resources)
if (Object.keys(publicResourceLabels).length > 0) {
processResourceLabels(
publicResourceLabels,
container,
result["public-resources"]
);
}
// Process private resources (alias for client resources)
if (Object.keys(privateResourceLabels).length > 0) {
processResourceLabels(
privateResourceLabels,
container,
result["private-resources"]
);
}
});
return result;

View File

@@ -3,6 +3,7 @@ import {
orgDomains,
Resource,
resourceHeaderAuth,
resourceHeaderAuthExtendedCompatibility,
resourcePincode,
resourceRules,
resourceWhitelist,
@@ -30,7 +31,8 @@ import { pickPort } from "@server/routers/target/helpers";
import { resourcePassword } from "@server/db";
import { hashPassword } from "@server/auth/password";
import { isValidCIDR, isValidIP, isValidUrlGlobPattern } from "../validators";
import { get } from "http";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "../billing/tierMatrix";
export type ProxyResourcesResults = {
proxyResource: Resource;
@@ -209,6 +211,15 @@ export async function updateProxyResources(
resource = existingResource;
} else {
// Update existing resource
const isLicensed = await isLicensedOrSubscribed(
orgId,
tierMatrix.maintencePage
);
if (!isLicensed) {
resourceData.maintenance = undefined;
}
[resource] = await trx
.update(resources)
.set({
@@ -233,7 +244,14 @@ export async function updateProxyResources(
: false,
headers: headers || null,
applyRules:
resourceData.rules && resourceData.rules.length > 0
resourceData.rules && resourceData.rules.length > 0,
maintenanceModeEnabled:
resourceData.maintenance?.enabled,
maintenanceModeType: resourceData.maintenance?.type,
maintenanceTitle: resourceData.maintenance?.title,
maintenanceMessage: resourceData.maintenance?.message,
maintenanceEstimatedTime:
resourceData.maintenance?.["estimated-time"]
})
.where(
eq(resources.resourceId, existingResource.resourceId)
@@ -287,21 +305,47 @@ export async function updateProxyResources(
existingResource.resourceId
)
);
await trx
.delete(resourceHeaderAuthExtendedCompatibility)
.where(
eq(
resourceHeaderAuthExtendedCompatibility.resourceId,
existingResource.resourceId
)
);
if (resourceData.auth?.["basic-auth"]) {
const headerAuthUser =
resourceData.auth?.["basic-auth"]?.user;
const headerAuthPassword =
resourceData.auth?.["basic-auth"]?.password;
if (headerAuthUser && headerAuthPassword) {
const headerAuthExtendedCompatibility =
resourceData.auth?.["basic-auth"]
?.extendedCompatibility;
if (
headerAuthUser &&
headerAuthPassword &&
headerAuthExtendedCompatibility !== null
) {
const headerAuthHash = await hashPassword(
Buffer.from(
`${headerAuthUser}:${headerAuthPassword}`
).toString("base64")
);
await trx.insert(resourceHeaderAuth).values({
resourceId: existingResource.resourceId,
headerAuthHash
});
await Promise.all([
trx.insert(resourceHeaderAuth).values({
resourceId: existingResource.resourceId,
headerAuthHash
}),
trx
.insert(resourceHeaderAuthExtendedCompatibility)
.values({
resourceId: existingResource.resourceId,
extendedCompatibilityIsActivated:
headerAuthExtendedCompatibility
})
]);
}
}
@@ -542,13 +586,18 @@ export async function updateProxyResources(
// Sync rules
for (const [index, rule] of resourceData.rules?.entries() || []) {
const intendedPriority = rule.priority ?? index + 1;
const existingRule = existingRules[index];
if (existingRule) {
if (
existingRule.action !== getRuleAction(rule.action) ||
existingRule.match !== rule.match.toUpperCase() ||
existingRule.value !==
getRuleValue(rule.match.toUpperCase(), rule.value)
getRuleValue(
rule.match.toUpperCase(),
rule.value
) ||
existingRule.priority !== intendedPriority
) {
validateRule(rule);
await trx
@@ -559,7 +608,8 @@ export async function updateProxyResources(
value: getRuleValue(
rule.match.toUpperCase(),
rule.value
)
),
priority: intendedPriority
})
.where(
eq(resourceRules.ruleId, existingRule.ruleId)
@@ -575,7 +625,7 @@ export async function updateProxyResources(
rule.match.toUpperCase(),
rule.value
),
priority: index + 1 // start priorities at 1
priority: intendedPriority
});
}
}
@@ -604,6 +654,14 @@ export async function updateProxyResources(
);
}
const isLicensed = await isLicensedOrSubscribed(
orgId,
tierMatrix.maintencePage
);
if (!isLicensed) {
resourceData.maintenance = undefined;
}
// Create new resource
const [newResource] = await trx
.insert(resources)
@@ -625,7 +683,13 @@ export async function updateProxyResources(
ssl: resourceSsl,
headers: headers || null,
applyRules:
resourceData.rules && resourceData.rules.length > 0
resourceData.rules && resourceData.rules.length > 0,
maintenanceModeEnabled: resourceData.maintenance?.enabled,
maintenanceModeType: resourceData.maintenance?.type,
maintenanceTitle: resourceData.maintenance?.title,
maintenanceMessage: resourceData.maintenance?.message,
maintenanceEstimatedTime:
resourceData.maintenance?.["estimated-time"]
})
.returning();
@@ -656,18 +720,33 @@ export async function updateProxyResources(
const headerAuthUser = resourceData.auth?.["basic-auth"]?.user;
const headerAuthPassword =
resourceData.auth?.["basic-auth"]?.password;
const headerAuthExtendedCompatibility =
resourceData.auth?.["basic-auth"]?.extendedCompatibility;
if (headerAuthUser && headerAuthPassword) {
if (
headerAuthUser &&
headerAuthPassword &&
headerAuthExtendedCompatibility !== null
) {
const headerAuthHash = await hashPassword(
Buffer.from(
`${headerAuthUser}:${headerAuthPassword}`
).toString("base64")
);
await trx.insert(resourceHeaderAuth).values({
resourceId: newResource.resourceId,
headerAuthHash
});
await Promise.all([
trx.insert(resourceHeaderAuth).values({
resourceId: newResource.resourceId,
headerAuthHash
}),
trx
.insert(resourceHeaderAuthExtendedCompatibility)
.values({
resourceId: newResource.resourceId,
extendedCompatibilityIsActivated:
headerAuthExtendedCompatibility
})
]);
}
}
@@ -734,7 +813,7 @@ export async function updateProxyResources(
action: getRuleAction(rule.action),
match: rule.match.toUpperCase(),
value: getRuleValue(rule.match.toUpperCase(), rule.value),
priority: index + 1 // start priorities at 1
priority: rule.priority ?? index + 1
});
}
@@ -865,7 +944,12 @@ async function syncUserResources(
.select()
.from(users)
.innerJoin(userOrgs, eq(users.userId, userOrgs.userId))
.where(and(eq(users.username, username), eq(userOrgs.orgId, orgId)))
.where(
and(
or(eq(users.username, username), eq(users.email, username)),
eq(userOrgs.orgId, orgId)
)
)
.limit(1);
if (!user) {

View File

@@ -1,4 +1,6 @@
import { z } from "zod";
import { portRangeStringSchema } from "@server/lib/ip";
import { MaintenanceSchema } from "#dynamic/lib/blueprints/MaintenanceSchema";
export const SiteSchema = z.object({
name: z.string().min(1).max(100),
@@ -55,7 +57,8 @@ export const AuthSchema = z.object({
"basic-auth": z
.object({
user: z.string().min(1),
password: z.string().min(1)
password: z.string().min(1),
extendedCompatibility: z.boolean().default(true)
})
.optional(),
"sso-enabled": z.boolean().optional().default(false),
@@ -66,7 +69,7 @@ export const AuthSchema = z.object({
.refine((roles) => !roles.includes("Admin"), {
error: "Admin role cannot be included in sso-roles"
}),
"sso-users": z.array(z.email()).optional().default([]),
"sso-users": z.array(z.string()).optional().default([]),
"whitelist-users": z.array(z.email()).optional().default([]),
"auto-login-idp": z.int().positive().optional()
});
@@ -75,7 +78,8 @@ export const RuleSchema = z
.object({
action: z.enum(["allow", "deny", "pass"]),
match: z.enum(["cidr", "path", "ip", "country", "asn"]),
value: z.string()
value: z.string(),
priority: z.int().optional()
})
.refine(
(rule) => {
@@ -108,32 +112,30 @@ export const RuleSchema = z
.refine(
(rule) => {
if (rule.match === "country") {
// Check if it's a valid 2-letter country code
return /^[A-Z]{2}$/.test(rule.value);
// Check if it's a valid 2-letter country code or "ALL"
return /^[A-Z]{2}$/.test(rule.value) || rule.value === "ALL";
}
return true;
},
{
path: ["value"],
message:
"Value must be a 2-letter country code when match is 'country'"
"Value must be a 2-letter country code or 'ALL' when match is 'country'"
}
)
.refine(
(rule) => {
if (rule.match === "asn") {
// Check if it's either AS<number> format or just a number
// Check if it's either AS<number> format or "ALL"
const asNumberPattern = /^AS\d+$/i;
const isASFormat = asNumberPattern.test(rule.value);
const isNumeric = /^\d+$/.test(rule.value);
return isASFormat || isNumeric;
return asNumberPattern.test(rule.value) || rule.value === "ALL";
}
return true;
},
{
path: ["value"],
message:
"Value must be either 'AS<number>' format or a number when match is 'asn'"
"Value must be 'AS<number>' format or 'ALL' when match is 'asn'"
}
);
@@ -156,7 +158,8 @@ export const ResourceSchema = z
"host-header": z.string().optional(),
"tls-server-name": z.string().optional(),
headers: z.array(HeaderSchema).optional(),
rules: z.array(RuleSchema).optional()
rules: z.array(RuleSchema).optional(),
maintenance: MaintenanceSchema.optional()
})
.refine(
(resource) => {
@@ -266,6 +269,39 @@ export const ResourceSchema = z
path: ["auth"],
error: "When protocol is 'tcp' or 'udp', 'auth' must not be provided"
}
)
.refine(
(resource) => {
// Skip validation for targets-only resources
if (isTargetsOnlyResource(resource)) {
return true;
}
// Skip validation if no rules are defined
if (!resource.rules || resource.rules.length === 0) return true;
const finalPriorities: number[] = [];
let priorityCounter = 1;
// Gather priorities, assigning auto-priorities where needed
// following the logic from the backend implementation where
// empty priorities are auto-assigned a value of 1 + index of rule
for (const rule of resource.rules) {
if (rule.priority !== undefined) {
finalPriorities.push(rule.priority);
} else {
finalPriorities.push(priorityCounter);
}
priorityCounter++;
}
// Validate for duplicate priorities
return finalPriorities.length === new Set(finalPriorities).size;
},
{
path: ["rules"],
message:
"Rules have conflicting or invalid priorities (must be unique, including auto-assigned ones)"
}
);
export function isTargetsOnlyResource(resource: any): boolean {
@@ -282,11 +318,14 @@ export const ClientResourceSchema = z
// destinationPort: z.int().positive().optional(),
destination: z.string().min(1),
// enabled: z.boolean().default(true),
"tcp-ports": portRangeStringSchema.optional().default("*"),
"udp-ports": portRangeStringSchema.optional().default("*"),
"disable-icmp": z.boolean().optional().default(false),
alias: z
.string()
.regex(
/^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/,
"Alias must be a fully qualified domain name (e.g., example.com)"
/^(?:[a-zA-Z0-9*?](?:[a-zA-Z0-9*?-]{0,61}[a-zA-Z0-9*?])?\.)+[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/,
"Alias must be a fully qualified domain name with optional wildcards (e.g., example.com, *.example.com, host-0?.example.internal)"
)
.optional(),
roles: z
@@ -296,7 +335,7 @@ export const ClientResourceSchema = z
.refine((roles) => !roles.includes("Admin"), {
error: "Admin role cannot be included in roles"
}),
users: z.array(z.email()).optional().default([]),
users: z.array(z.string()).optional().default([]),
machines: z.array(z.string()).optional().default([])
})
.refine(

View File

@@ -1,5 +1,161 @@
import NodeCache from "node-cache";
import logger from "@server/logger";
export const cache = new NodeCache({ stdTTL: 3600, checkperiod: 120 });
// Create local cache with maxKeys limit to prevent memory leaks
// With ~10k requests/day and 5min TTL, 10k keys should be more than sufficient
export const localCache = new NodeCache({
stdTTL: 3600,
checkperiod: 120,
maxKeys: 10000
});
// Log cache statistics periodically for monitoring
setInterval(() => {
const stats = localCache.getStats();
logger.debug(
`Local cache stats - Keys: ${stats.keys}, Hits: ${stats.hits}, Misses: ${stats.misses}, Hit rate: ${stats.hits > 0 ? ((stats.hits / (stats.hits + stats.misses)) * 100).toFixed(2) : 0}%`
);
}, 300000); // Every 5 minutes
/**
* Adaptive cache that uses Redis when available in multi-node environments,
* otherwise falls back to local memory cache for single-node deployments.
*/
class AdaptiveCache {
/**
* Set a value in the cache
* @param key - Cache key
* @param value - Value to cache (will be JSON stringified for Redis)
* @param ttl - Time to live in seconds (0 = no expiration)
* @returns boolean indicating success
*/
async set(key: string, value: any, ttl?: number): Promise<boolean> {
const effectiveTtl = ttl === 0 ? undefined : ttl;
// Use local cache as fallback or primary
const success = localCache.set(key, value, effectiveTtl || 0);
if (success) {
logger.debug(`Set key in local cache: ${key}`);
}
return success;
}
/**
* Get a value from the cache
* @param key - Cache key
* @returns The cached value or undefined if not found
*/
async get<T = any>(key: string): Promise<T | undefined> {
// Use local cache as fallback or primary
const value = localCache.get<T>(key);
if (value !== undefined) {
logger.debug(`Cache hit in local cache: ${key}`);
} else {
logger.debug(`Cache miss in local cache: ${key}`);
}
return value;
}
/**
* Delete a value from the cache
* @param key - Cache key or array of keys
* @returns Number of deleted entries
*/
async del(key: string | string[]): Promise<number> {
const keys = Array.isArray(key) ? key : [key];
let deletedCount = 0;
// Use local cache as fallback or primary
for (const k of keys) {
const success = localCache.del(k);
if (success > 0) {
deletedCount++;
logger.debug(`Deleted key from local cache: ${k}`);
}
}
return deletedCount;
}
/**
* Check if a key exists in the cache
* @param key - Cache key
* @returns boolean indicating if key exists
*/
async has(key: string): Promise<boolean> {
// Use local cache as fallback or primary
return localCache.has(key);
}
/**
* Get multiple values from the cache
* @param keys - Array of cache keys
* @returns Array of values (undefined for missing keys)
*/
async mget<T = any>(keys: string[]): Promise<(T | undefined)[]> {
// Use local cache as fallback or primary
return keys.map((key) => localCache.get<T>(key));
}
/**
* Flush all keys from the cache
*/
async flushAll(): Promise<void> {
localCache.flushAll();
logger.debug("Flushed local cache");
}
/**
* Get cache statistics
* Note: Only returns local cache stats, Redis stats are not included
*/
getStats() {
return localCache.getStats();
}
/**
* Get the current cache backend being used
* @returns "redis" if Redis is available and healthy, "local" otherwise
*/
getCurrentBackend(): "redis" | "local" {
return "local";
}
/**
* Take a key from the cache and delete it
* @param key - Cache key
* @returns The value or undefined if not found
*/
async take<T = any>(key: string): Promise<T | undefined> {
const value = await this.get<T>(key);
if (value !== undefined) {
await this.del(key);
}
return value;
}
/**
* Get TTL (time to live) for a key
* @param key - Cache key
* @returns TTL in seconds, 0 if no expiration, -1 if key doesn't exist
*/
getTtl(key: string): number {
const ttl = localCache.getTtl(key);
if (ttl === undefined) {
return -1;
}
return Math.max(0, Math.floor((ttl - Date.now()) / 1000));
}
/**
* Get all keys from the cache
* Note: Only returns local cache keys, Redis keys are not included
*/
keys(): string[] {
return localCache.keys();
}
}
// Export singleton instance
export const cache = new AdaptiveCache();
export default cache;

View File

@@ -1,21 +1,27 @@
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { build } from "@server/build";
import {
approvals,
clients,
db,
olms,
orgs,
roleClients,
roles,
Transaction,
userClients,
userOrgs,
Transaction
userOrgRoles,
userOrgs
} from "@server/db";
import { eq, and, notInArray } from "drizzle-orm";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import logger from "@server/logger";
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { getUniqueClientName } from "@server/db/names";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import logger from "@server/logger";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { and, eq, notInArray, type InferInsertModel } from "drizzle-orm";
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
import { OlmErrorCodes } from "@server/routers/olm/error";
import { tierMatrix } from "./billing/tierMatrix";
export async function calculateUserClientsForOrgs(
userId: string,
@@ -34,18 +40,36 @@ export async function calculateUserClientsForOrgs(
return;
}
// Get all user orgs
const allUserOrgs = await transaction
// Get all user orgs with all roles (for org list and role-based logic)
const userOrgRoleRows = await transaction
.select()
.from(userOrgs)
.innerJoin(
userOrgRoles,
and(
eq(userOrgs.userId, userOrgRoles.userId),
eq(userOrgs.orgId, userOrgRoles.orgId)
)
)
.innerJoin(roles, eq(userOrgRoles.roleId, roles.roleId))
.where(eq(userOrgs.userId, userId));
const userOrgIds = allUserOrgs.map((uo) => uo.orgId);
const userOrgIds = [...new Set(userOrgRoleRows.map((r) => r.userOrgs.orgId))];
const orgIdToRoleRows = new Map<
string,
(typeof userOrgRoleRows)[0][]
>();
for (const r of userOrgRoleRows) {
const list = orgIdToRoleRows.get(r.userOrgs.orgId) ?? [];
list.push(r);
orgIdToRoleRows.set(r.userOrgs.orgId, list);
}
// For each OLM, ensure there's a client in each org the user is in
for (const olm of userOlms) {
for (const userOrg of allUserOrgs) {
const orgId = userOrg.orgId;
for (const orgId of orgIdToRoleRows.keys()) {
const roleRowsForOrg = orgIdToRoleRows.get(orgId)!;
const userOrg = roleRowsForOrg[0].userOrgs;
const [org] = await transaction
.select()
@@ -182,21 +206,47 @@ export async function calculateUserClientsForOrgs(
const niceId = await getUniqueClientName(orgId);
const isOrgLicensed = await isLicensedOrSubscribed(
userOrg.orgId,
tierMatrix.deviceApprovals
);
const requireApproval =
build !== "oss" &&
isOrgLicensed &&
roleRowsForOrg.some((r) => r.roles.requireDeviceApproval);
const newClientData: InferInsertModel<typeof clients> = {
userId,
orgId: userOrg.orgId,
exitNodeId: randomExitNode.exitNodeId,
name: olm.name || "User Client",
subnet: updatedSubnet,
olmId: olm.olmId,
type: "olm",
niceId,
approvalState: requireApproval ? "pending" : null
};
// Create the client
const [newClient] = await transaction
.insert(clients)
.values({
userId,
orgId: userOrg.orgId,
exitNodeId: randomExitNode.exitNodeId,
name: olm.name || "User Client",
subnet: updatedSubnet,
olmId: olm.olmId,
type: "olm",
niceId
})
.values(newClientData)
.returning();
// create approval request
if (requireApproval) {
await transaction
.insert(approvals)
.values({
timestamp: Math.floor(new Date().getTime() / 1000),
orgId: userOrg.orgId,
clientId: newClient.clientId,
userId,
type: "user_device"
})
.returning();
}
await rebuildClientAssociationsFromClient(
newClient,
transaction
@@ -275,6 +325,7 @@ async function cleanupOrphanedClients(
if (deletedClient.olmId) {
await sendTerminateClient(
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
deletedClient.olmId
);
}

View File

@@ -2,9 +2,15 @@ import { db, orgs } from "@server/db";
import { cleanUpOldLogs as cleanUpOldAccessLogs } from "#dynamic/lib/logAccessAudit";
import { cleanUpOldLogs as cleanUpOldActionLogs } from "#dynamic/middlewares/logActionAudit";
import { cleanUpOldLogs as cleanUpOldRequestLogs } from "@server/routers/badger/logRequestAudit";
import { cleanUpOldLogs as cleanUpOldConnectionLogs } from "#dynamic/routers/newt";
import { gt, or } from "drizzle-orm";
import { cleanUpOldFingerprintSnapshots } from "@server/routers/olm/fingerprintingUtils";
import { build } from "@server/build";
export function initLogCleanupInterval() {
if (build == "saas") { // skip log cleanup for saas builds
return null;
}
return setInterval(
async () => {
const orgsToClean = await db
@@ -15,23 +21,28 @@ export function initLogCleanupInterval() {
settingsLogRetentionDaysAccess:
orgs.settingsLogRetentionDaysAccess,
settingsLogRetentionDaysRequest:
orgs.settingsLogRetentionDaysRequest
orgs.settingsLogRetentionDaysRequest,
settingsLogRetentionDaysConnection:
orgs.settingsLogRetentionDaysConnection
})
.from(orgs)
.where(
or(
gt(orgs.settingsLogRetentionDaysAction, 0),
gt(orgs.settingsLogRetentionDaysAccess, 0),
gt(orgs.settingsLogRetentionDaysRequest, 0)
gt(orgs.settingsLogRetentionDaysRequest, 0),
gt(orgs.settingsLogRetentionDaysConnection, 0)
)
);
// TODO: handle when there are multiple nodes doing this clearing using redis
for (const org of orgsToClean) {
const {
orgId,
settingsLogRetentionDaysAction,
settingsLogRetentionDaysAccess,
settingsLogRetentionDaysRequest
settingsLogRetentionDaysRequest,
settingsLogRetentionDaysConnection
} = org;
if (settingsLogRetentionDaysAction > 0) {
@@ -54,7 +65,16 @@ export function initLogCleanupInterval() {
settingsLogRetentionDaysRequest
);
}
if (settingsLogRetentionDaysConnection > 0) {
await cleanUpOldConnectionLogs(
orgId,
settingsLogRetentionDaysConnection
);
}
}
await cleanUpOldFingerprintSnapshots(365);
},
3 * 60 * 60 * 1000
); // every 3 hours

View File

@@ -0,0 +1,20 @@
import semver from "semver";
export function canCompress(
clientVersion: string | null | undefined,
type: "newt" | "olm"
): boolean {
try {
if (!clientVersion) return false;
// check if it is a valid semver
if (!semver.valid(clientVersion)) return false;
if (type === "newt") {
return semver.gte(clientVersion, "1.10.3");
} else if (type === "olm") {
return semver.gte(clientVersion, "1.4.3");
}
return false;
} catch {
return false;
}
}

View File

@@ -84,6 +84,10 @@ export class Config {
?.disable_basic_wireguard_sites
? "true"
: "false";
process.env.FLAGS_DISABLE_PRODUCT_HELP_BANNERS = parsedConfig.flags
?.disable_product_help_banners
? "true"
: "false";
process.env.PRODUCT_UPDATES_NOTIFICATION_ENABLED = parsedConfig.app
.notifications.product_updates
@@ -103,6 +107,11 @@ export class Config {
process.env.MAXMIND_ASN_PATH = parsedConfig.server.maxmind_asn_path;
}
process.env.DISABLE_ENTERPRISE_FEATURES = parsedConfig.flags
?.disable_enterprise_features
? "true"
: "false";
this.rawConfig = parsedConfig;
}

View File

@@ -2,7 +2,7 @@ import path from "path";
import { fileURLToPath } from "url";
// This is a placeholder value replaced by the build process
export const APP_VERSION = "1.13.1";
export const APP_VERSION = "1.16.0";
export const __FILENAME = fileURLToPath(import.meta.url);
export const __DIRNAME = path.dirname(__FILENAME);

View File

@@ -1,197 +0,0 @@
import { isValidCIDR } from "@server/lib/validators";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import {
actions,
apiKeyOrg,
apiKeys,
db,
domains,
Org,
orgDomains,
orgs,
roleActions,
roles,
userOrgs
} from "@server/db";
import { eq } from "drizzle-orm";
import { defaultRoleAllowedActions } from "@server/routers/role";
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import config from "@server/lib/config";
export async function createUserAccountOrg(
userId: string,
userEmail: string
): Promise<{
success: boolean;
org?: {
orgId: string;
name: string;
subnet: string;
};
error?: string;
}> {
// const subnet = await getNextAvailableOrgSubnet();
const orgId = "org_" + userId;
const name = `${userEmail}'s Organization`;
// if (!isValidCIDR(subnet)) {
// return {
// success: false,
// error: "Invalid subnet format. Please provide a valid CIDR notation."
// };
// }
// // make sure the subnet is unique
// const subnetExists = await db
// .select()
// .from(orgs)
// .where(eq(orgs.subnet, subnet))
// .limit(1);
// if (subnetExists.length > 0) {
// return { success: false, error: `Subnet ${subnet} already exists` };
// }
// make sure the orgId is unique
const orgExists = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (orgExists.length > 0) {
return {
success: false,
error: `Organization with ID ${orgId} already exists`
};
}
let error = "";
let org: Org | null = null;
await db.transaction(async (trx) => {
const allDomains = await trx
.select()
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
const newOrg = await trx
.insert(orgs)
.values({
orgId,
name,
// subnet
subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
utilitySubnet: utilitySubnet,
createdAt: new Date().toISOString()
})
.returning();
if (newOrg.length === 0) {
error = "Failed to create organization";
trx.rollback();
return;
}
org = newOrg[0];
// Create admin role within the same transaction
const [insertedRole] = await trx
.insert(roles)
.values({
orgId: newOrg[0].orgId,
isAdmin: true,
name: "Admin",
description: "Admin role with the most permissions"
})
.returning({ roleId: roles.roleId });
if (!insertedRole || !insertedRole.roleId) {
error = "Failed to create Admin role";
trx.rollback();
return;
}
const roleId = insertedRole.roleId;
// Get all actions and create role actions
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length > 0) {
await trx.insert(roleActions).values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
}
if (allDomains.length) {
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
}
await trx.insert(userOrgs).values({
userId,
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
const memberRole = await trx
.insert(roles)
.values({
name: "Member",
description: "Members can only view resources",
orgId
})
.returning();
await trx.insert(roleActions).values(
defaultRoleAllowedActions.map((action) => ({
roleId: memberRole[0].roleId,
actionId: action,
orgId
}))
);
});
await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet);
if (!org) {
return { success: false, error: "Failed to create org" };
}
if (error) {
return {
success: false,
error: `Failed to create org: ${error}`
};
}
// make sure we have the stripe customer
const customerId = await createCustomer(orgId, userEmail);
if (customerId) {
await usageService.updateDaily(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org
}
return {
org: {
orgId,
name,
// subnet
subnet: "100.90.128.0/24"
},
success: true
};
}

246
server/lib/deleteOrg.ts Normal file
View File

@@ -0,0 +1,246 @@
import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
domains,
exitNodeOrgs,
exitNodes,
olms,
orgDomains,
orgs,
remoteExitNodes,
resources,
sites,
userOrgs
} from "@server/db";
import { newts, newtSessions } from "@server/db";
import { eq, and, inArray, sql, count, countDistinct } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { sendToClient } from "#dynamic/routers/ws";
import { deletePeer } from "@server/routers/gerbil/peers";
import { OlmErrorCodes } from "@server/routers/olm/error";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { usageService } from "./billing/usageService";
import { FeatureId } from "./billing";
export type DeleteOrgByIdResult = {
deletedNewtIds: string[];
olmsToTerminate: string[];
};
/**
* Deletes one organization and its related data. Returns ids for termination
* messages; caller should call sendTerminationMessages with the result.
* Throws if org not found.
*/
export async function deleteOrgById(
orgId: string
): Promise<DeleteOrgByIdResult> {
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
);
}
const orgSites = await db
.select()
.from(sites)
.where(eq(sites.orgId, orgId))
.limit(1);
const orgClients = await db
.select()
.from(clients)
.where(eq(clients.orgId, orgId));
const deletedNewtIds: string[] = [];
const olmsToTerminate: string[] = [];
let domainCount: number | null = null;
let siteCount: number | null = null;
let userCount: number | null = null;
let remoteExitNodeCount: number | null = null;
await db.transaction(async (trx) => {
for (const site of orgSites) {
if (site.pubKey) {
if (site.type == "wireguard") {
await deletePeer(site.exitNodeId!, site.pubKey);
} else if (site.type == "newt") {
const [deletedNewt] = await trx
.delete(newts)
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
deletedNewtIds.push(deletedNewt.newtId);
await trx
.delete(newtSessions)
.where(eq(newtSessions.newtId, deletedNewt.newtId));
}
}
}
logger.info(`Deleting site ${site.siteId}`);
await trx.delete(sites).where(eq(sites.siteId, site.siteId));
}
for (const client of orgClients) {
const [olm] = await trx
.select()
.from(olms)
.where(eq(olms.clientId, client.clientId))
.limit(1);
if (olm) {
olmsToTerminate.push(olm.olmId);
}
logger.info(`Deleting client ${client.clientId}`);
await trx
.delete(clients)
.where(eq(clients.clientId, client.clientId));
await trx
.delete(clientSiteResourcesAssociationsCache)
.where(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
);
await trx
.delete(clientSitesAssociationsCache)
.where(
eq(clientSitesAssociationsCache.clientId, client.clientId)
);
}
await trx.delete(resources).where(eq(resources.orgId, orgId));
const allOrgDomains = await trx
.select()
.from(orgDomains)
.innerJoin(domains, eq(orgDomains.domainId, domains.domainId))
.where(
and(
eq(orgDomains.orgId, orgId),
eq(domains.configManaged, false)
)
);
logger.info(`Found ${allOrgDomains.length} domains to delete`);
const domainIdsToDelete: string[] = [];
for (const orgDomain of allOrgDomains) {
const domainId = orgDomain.domains.domainId;
const [orgCount] = await trx
.select({ count: count() })
.from(orgDomains)
.where(eq(orgDomains.domainId, domainId));
logger.info(`Found ${orgCount.count} orgs using domain ${domainId}`);
if (orgCount.count === 1) {
domainIdsToDelete.push(domainId);
}
}
logger.info(`Found ${domainIdsToDelete.length} domains to delete`);
if (domainIdsToDelete.length > 0) {
await trx
.delete(domains)
.where(inArray(domains.domainId, domainIdsToDelete));
}
await usageService.add(orgId, FeatureId.ORGINIZATIONS, -1, trx); // here we are decreasing the org count BEFORE deleting the org because we need to still be able to get the org to get the billing org inside of here
await trx.delete(orgs).where(eq(orgs.orgId, orgId));
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (billingOrgs.length > 0) {
const billingOrgIds = billingOrgs.map((org) => org.orgId);
const [domainCountRes] = await trx
.select({ count: count() })
.from(orgDomains)
.where(inArray(orgDomains.orgId, billingOrgIds));
domainCount = domainCountRes.count;
const [siteCountRes] = await trx
.select({ count: count() })
.from(sites)
.where(inArray(sites.orgId, billingOrgIds));
siteCount = siteCountRes.count;
const [userCountRes] = await trx
.select({ count: countDistinct(userOrgs.userId) })
.from(userOrgs)
.where(inArray(userOrgs.orgId, billingOrgIds));
userCount = userCountRes.count;
const [remoteExitNodeCountRes] = await trx
.select({ count: countDistinct(exitNodeOrgs.exitNodeId) })
.from(exitNodeOrgs)
.where(inArray(exitNodeOrgs.orgId, billingOrgIds));
remoteExitNodeCount = remoteExitNodeCountRes.count;
}
}
});
if (org.billingOrgId) {
usageService.updateCount(
org.billingOrgId,
FeatureId.DOMAINS,
domainCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.SITES,
siteCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.USERS,
userCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.REMOTE_EXIT_NODES,
remoteExitNodeCount ?? 0
);
}
return { deletedNewtIds, olmsToTerminate };
}
export function sendTerminationMessages(result: DeleteOrgByIdResult): void {
for (const newtId of result.deletedNewtIds) {
sendToClient(newtId, { type: `newt/wg/terminate`, data: {} }).catch(
(error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
}
);
}
for (const olmId of result.olmsToTerminate) {
sendTerminateClient(0, OlmErrorCodes.TERMINATED_REKEYED, olmId).catch(
(error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
}
);
}
}

View File

@@ -0,0 +1,3 @@
export const getEnvOrYaml = (envVar: string) => (valFromYaml: any) => {
return process.env[envVar] ?? valFromYaml;
};

View File

@@ -1,15 +1,10 @@
import {
clientSitesAssociationsCache,
db,
SiteResource,
siteResources,
Transaction
} from "@server/db";
import { db, SiteResource, siteResources, Transaction } from "@server/db";
import { clients, orgs, sites } from "@server/db";
import { and, eq, isNotNull } from "drizzle-orm";
import config from "@server/lib/config";
import z from "zod";
import logger from "@server/logger";
import semver from "semver";
interface IPRange {
start: bigint;
@@ -307,6 +302,26 @@ export function isIpInCidr(ip: string, cidr: string): boolean {
return ipBigInt >= range.start && ipBigInt <= range.end;
}
/**
* Checks if two CIDR ranges overlap
* @param cidr1 First CIDR string
* @param cidr2 Second CIDR string
* @returns boolean indicating if the two CIDRs overlap
*/
export function doCidrsOverlap(cidr1: string, cidr2: string): boolean {
const version1 = detectIpVersion(cidr1.split("/")[0]);
const version2 = detectIpVersion(cidr2.split("/")[0]);
if (version1 !== version2) {
// Different IP versions cannot overlap
return false;
}
const range1 = cidrToRange(cidr1);
const range2 = cidrToRange(cidr2);
// Overlap if the ranges intersect
return range1.start <= range2.end && range2.start <= range1.end;
}
export async function getNextAvailableClientSubnet(
orgId: string,
transaction: Transaction | typeof db = db
@@ -472,10 +487,12 @@ export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] {
export type SubnetProxyTarget = {
sourcePrefix: string; // must be a cidr
destPrefix: string; // must be a cidr
disableIcmp?: boolean;
rewriteTo?: string; // must be a cidr
portRange?: {
min: number;
max: number;
protocol: "tcp" | "udp";
}[];
};
@@ -505,6 +522,11 @@ export function generateSubnetProxyTargets(
}
const clientPrefix = `${clientSite.subnet.split("/")[0]}/32`;
const portRange = [
...parsePortRangeString(siteResource.tcpPortRangeString, "tcp"),
...parsePortRangeString(siteResource.udpPortRangeString, "udp")
];
const disableIcmp = siteResource.disableIcmp ?? false;
if (siteResource.mode == "host") {
let destination = siteResource.destination;
@@ -515,7 +537,9 @@ export function generateSubnetProxyTargets(
targets.push({
sourcePrefix: clientPrefix,
destPrefix: destination
destPrefix: destination,
portRange,
disableIcmp
});
}
@@ -524,13 +548,17 @@ export function generateSubnetProxyTargets(
targets.push({
sourcePrefix: clientPrefix,
destPrefix: `${siteResource.aliasAddress}/32`,
rewriteTo: destination
rewriteTo: destination,
portRange,
disableIcmp
});
}
} else if (siteResource.mode == "cidr") {
targets.push({
sourcePrefix: clientPrefix,
destPrefix: siteResource.destination
destPrefix: siteResource.destination,
portRange,
disableIcmp
});
}
}
@@ -542,3 +570,276 @@ export function generateSubnetProxyTargets(
return targets;
}
export type SubnetProxyTargetV2 = {
sourcePrefixes: string[]; // must be cidrs
destPrefix: string; // must be a cidr
disableIcmp?: boolean;
rewriteTo?: string; // must be a cidr
portRange?: {
min: number;
max: number;
protocol: "tcp" | "udp";
}[];
resourceId?: number;
};
export function generateSubnetProxyTargetV2(
siteResource: SiteResource,
clients: {
clientId: number;
pubKey: string | null;
subnet: string | null;
}[]
): SubnetProxyTargetV2 | undefined {
if (clients.length === 0) {
logger.debug(
`No clients have access to site resource ${siteResource.siteResourceId}, skipping target generation.`
);
return;
}
let target: SubnetProxyTargetV2 | null = null;
const portRange = [
...parsePortRangeString(siteResource.tcpPortRangeString, "tcp"),
...parsePortRangeString(siteResource.udpPortRangeString, "udp")
];
const disableIcmp = siteResource.disableIcmp ?? false;
if (siteResource.mode == "host") {
let destination = siteResource.destination;
// check if this is a valid ip
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
if (ipSchema.safeParse(destination).success) {
destination = `${destination}/32`;
target = {
sourcePrefixes: [],
destPrefix: destination,
portRange,
disableIcmp,
resourceId: siteResource.siteResourceId,
};
}
if (siteResource.alias && siteResource.aliasAddress) {
// also push a match for the alias address
target = {
sourcePrefixes: [],
destPrefix: `${siteResource.aliasAddress}/32`,
rewriteTo: destination,
portRange,
disableIcmp,
resourceId: siteResource.siteResourceId,
};
}
} else if (siteResource.mode == "cidr") {
target = {
sourcePrefixes: [],
destPrefix: siteResource.destination,
portRange,
disableIcmp,
resourceId: siteResource.siteResourceId,
};
}
if (!target) {
return;
}
for (const clientSite of clients) {
if (!clientSite.subnet) {
logger.debug(
`Client ${clientSite.clientId} has no subnet, skipping for site resource ${siteResource.siteResourceId}.`
);
continue;
}
const clientPrefix = `${clientSite.subnet.split("/")[0]}/32`;
// add client prefix to source prefixes
target.sourcePrefixes.push(clientPrefix);
}
// print a nice representation of the targets
// logger.debug(
// `Generated subnet proxy targets for: ${JSON.stringify(targets, null, 2)}`
// );
return target;
}
/**
* Converts a SubnetProxyTargetV2 to an array of SubnetProxyTarget (v1)
* by expanding each source prefix into its own target entry.
* @param targetV2 - The v2 target to convert
* @returns Array of v1 SubnetProxyTarget objects
*/
export function convertSubnetProxyTargetsV2ToV1(
targetsV2: SubnetProxyTargetV2[]
): SubnetProxyTarget[] {
return targetsV2.flatMap((targetV2) =>
targetV2.sourcePrefixes.map((sourcePrefix) => ({
sourcePrefix,
destPrefix: targetV2.destPrefix,
...(targetV2.disableIcmp !== undefined && {
disableIcmp: targetV2.disableIcmp
}),
...(targetV2.rewriteTo !== undefined && {
rewriteTo: targetV2.rewriteTo
}),
...(targetV2.portRange !== undefined && {
portRange: targetV2.portRange
})
}))
);
}
// Custom schema for validating port range strings
// Format: "80,443,8000-9000" or "*" for all ports, or empty string
export const portRangeStringSchema = z
.string()
.optional()
.refine(
(val) => {
if (!val || val.trim() === "" || val.trim() === "*") {
return true;
}
// Split by comma and validate each part
const parts = val.split(",").map((p) => p.trim());
for (const part of parts) {
if (part === "") {
return false; // empty parts not allowed
}
// Check if it's a range (contains dash)
if (part.includes("-")) {
const [start, end] = part.split("-").map((p) => p.trim());
// Both parts must be present
if (!start || !end) {
return false;
}
const startPort = parseInt(start, 10);
const endPort = parseInt(end, 10);
// Must be valid numbers
if (isNaN(startPort) || isNaN(endPort)) {
return false;
}
// Must be valid port range (1-65535)
if (
startPort < 1 ||
startPort > 65535 ||
endPort < 1 ||
endPort > 65535
) {
return false;
}
// Start must be <= end
if (startPort > endPort) {
return false;
}
} else {
// Single port
const port = parseInt(part, 10);
// Must be a valid number
if (isNaN(port)) {
return false;
}
// Must be valid port range (1-65535)
if (port < 1 || port > 65535) {
return false;
}
}
}
return true;
},
{
message:
'Port range must be "*" for all ports, or a comma-separated list of ports and ranges (e.g., "80,443,8000-9000"). Ports must be between 1 and 65535, and ranges must have start <= end.'
}
);
/**
* Parses a port range string into an array of port range objects
* @param portRangeStr - Port range string (e.g., "80,443,8000-9000", "*", or "")
* @param protocol - Protocol to use for all ranges (default: "tcp")
* @returns Array of port range objects with min, max, and protocol fields
*/
export function parsePortRangeString(
portRangeStr: string | undefined | null,
protocol: "tcp" | "udp" = "tcp"
): { min: number; max: number; protocol: "tcp" | "udp" }[] {
// Handle undefined or empty string - insert dummy value with port 0
if (!portRangeStr || portRangeStr.trim() === "") {
return [{ min: 0, max: 0, protocol }];
}
// Handle wildcard - return empty array (all ports allowed)
if (portRangeStr.trim() === "*") {
return [];
}
const result: { min: number; max: number; protocol: "tcp" | "udp" }[] = [];
const parts = portRangeStr.split(",").map((p) => p.trim());
for (const part of parts) {
if (part.includes("-")) {
// Range
const [start, end] = part.split("-").map((p) => p.trim());
const startPort = parseInt(start, 10);
const endPort = parseInt(end, 10);
result.push({ min: startPort, max: endPort, protocol });
} else {
// Single port
const port = parseInt(part, 10);
result.push({ min: port, max: port, protocol });
}
}
return result;
}
export function stripPortFromHost(ip: string, badgerVersion?: string): string {
const isNewerBadger =
badgerVersion &&
semver.valid(badgerVersion) &&
semver.gte(badgerVersion, "1.3.1");
if (isNewerBadger) {
return ip;
}
if (ip.startsWith("[") && ip.includes("]")) {
// if brackets are found, extract the IPv6 address from between the brackets
const ipv6Match = ip.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
// Check if it looks like IPv4 (contains dots and matches IPv4 pattern)
// IPv4 format: x.x.x.x where x is 0-255
const ipv4Pattern = /^(\d{1,3}\.){3}\d{1,3}/;
if (ipv4Pattern.test(ip)) {
const lastColonIndex = ip.lastIndexOf(":");
if (lastColonIndex !== -1) {
return ip.substring(0, lastColonIndex);
}
}
// Return as is
return ip;
}

View File

@@ -0,0 +1,8 @@
import { Tier } from "@server/types/Tiers";
export async function isLicensedOrSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
return false;
}

View File

@@ -0,0 +1,8 @@
import { Tier } from "@server/types/Tiers";
export async function isSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
return false;
}

View File

@@ -0,0 +1,18 @@
/**
* Normalizes a post-authentication path for safe use when building redirect URLs.
* Returns a path that starts with / and does not allow open redirects (no //, no :).
*/
export function normalizePostAuthPath(path: string | null | undefined): string | null {
if (path == null || typeof path !== "string") {
return null;
}
const trimmed = path.trim();
if (trimmed === "") {
return null;
}
// Reject protocol-relative (//) or scheme (:) to avoid open redirect
if (trimmed.includes("//") || trimmed.includes(":")) {
return null;
}
return trimmed.startsWith("/") ? trimmed : `/${trimmed}`;
}

View File

@@ -3,13 +3,10 @@ import yaml from "js-yaml";
import { configFilePath1, configFilePath2 } from "./consts";
import { z } from "zod";
import stoi from "./stoi";
import { getEnvOrYaml } from "./getEnvOrYaml";
const portSchema = z.number().positive().gt(0).lte(65535);
const getEnvOrYaml = (envVar: string) => (valFromYaml: any) => {
return process.env[envVar] ?? valFromYaml;
};
export const configSchema = z
.object({
app: z
@@ -82,6 +79,7 @@ export const configSchema = z
.default(3001)
.transform(stoi)
.pipe(portSchema),
badger_override: z.string().optional(),
next_port: portSchema
.optional()
.default(3002)
@@ -192,6 +190,46 @@ export const configSchema = z
.prefault({})
})
.optional(),
postgres_logs: z
.object({
connection_string: z
.string()
.optional()
.transform(getEnvOrYaml("POSTGRES_LOGS_CONNECTION_STRING")),
replicas: z
.array(
z.object({
connection_string: z.string()
})
)
.optional(),
pool: z
.object({
max_connections: z
.number()
.positive()
.optional()
.default(20),
max_replica_connections: z
.number()
.positive()
.optional()
.default(10),
idle_timeout_ms: z
.number()
.positive()
.optional()
.default(30000),
connection_timeout_ms: z
.number()
.positive()
.optional()
.default(5000)
})
.optional()
.prefault({})
})
.optional(),
traefik: z
.object({
http_entrypoint: z.string().optional().default("web"),
@@ -256,17 +294,17 @@ export const configSchema = z
orgs: z
.object({
block_size: z.number().positive().gt(0).optional().default(24),
subnet_group: z.string().optional().default("100.90.128.0/24"),
subnet_group: z.string().optional().default("100.90.128.0/20"),
utility_subnet_group: z
.string()
.optional()
.default("100.96.128.0/24") //just hardcode this for now as well
.default("100.96.128.0/20") //just hardcode this for now as well
})
.optional()
.default({
block_size: 24,
subnet_group: "100.90.128.0/24",
utility_subnet_group: "100.96.128.0/24"
subnet_group: "100.90.128.0/20",
utility_subnet_group: "100.96.128.0/20"
}),
rate_limits: z
.object({
@@ -311,7 +349,10 @@ export const configSchema = z
.object({
smtp_host: z.string().optional(),
smtp_port: portSchema.optional(),
smtp_user: z.string().optional(),
smtp_user: z
.string()
.optional()
.transform(getEnvOrYaml("EMAIL_SMTP_USER")),
smtp_pass: z
.string()
.optional()
@@ -330,7 +371,9 @@ export const configSchema = z
enable_integration_api: z.boolean().optional(),
disable_local_sites: z.boolean().optional(),
disable_basic_wireguard_sites: z.boolean().optional(),
disable_config_managed_domains: z.boolean().optional()
disable_config_managed_domains: z.boolean().optional(),
disable_product_help_banners: z.boolean().optional(),
disable_enterprise_features: z.boolean().optional()
})
.optional(),
dns: z

View File

@@ -14,6 +14,7 @@ import {
siteResources,
sites,
Transaction,
userOrgRoles,
userOrgs,
userSiteResources
} from "@server/db";
@@ -32,7 +33,7 @@ import logger from "@server/logger";
import {
generateAliasConfig,
generateRemoteSubnets,
generateSubnetProxyTargets,
generateSubnetProxyTargetV2,
parseEndpoint,
formatEndpoint
} from "@server/lib/ip";
@@ -77,10 +78,10 @@ export async function getClientSiteResourceAccess(
// get all of the users in these roles
const userIdsFromRoles = await trx
.select({
userId: userOrgs.userId
userId: userOrgRoles.userId
})
.from(userOrgs)
.where(inArray(userOrgs.roleId, roleIds))
.from(userOrgRoles)
.where(inArray(userOrgRoles.roleId, roleIds))
.then((rows) => rows.map((row) => row.userId));
const newAllUserIds = Array.from(
@@ -477,6 +478,7 @@ async function handleMessagesForSiteClients(
}
if (isAdd) {
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
@@ -571,7 +573,7 @@ export async function updateClientSiteDestinations(
destinations: [
{
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
destinationPort: site.sites.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
}
]
};
@@ -579,7 +581,7 @@ export async function updateClientSiteDestinations(
// add to the existing destinations
destinations.destinations.push({
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
destinationPort: site.sites.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
});
}
@@ -659,17 +661,18 @@ async function handleSubnetProxyTargetUpdates(
);
if (addedClients.length > 0) {
const targetsToAdd = generateSubnetProxyTargets(
const targetToAdd = generateSubnetProxyTargetV2(
siteResource,
addedClients
);
if (targetsToAdd.length > 0) {
logger.info(
`Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
);
if (targetToAdd) {
proxyJobs.push(
addSubnetProxyTargets(newt.newtId, targetsToAdd)
addSubnetProxyTargets(
newt.newtId,
[targetToAdd],
newt.version
)
);
}
@@ -695,17 +698,18 @@ async function handleSubnetProxyTargetUpdates(
);
if (removedClients.length > 0) {
const targetsToRemove = generateSubnetProxyTargets(
const targetToRemove = generateSubnetProxyTargetV2(
siteResource,
removedClients
);
if (targetsToRemove.length > 0) {
logger.info(
`Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
);
if (targetToRemove) {
proxyJobs.push(
removeSubnetProxyTargets(newt.newtId, targetsToRemove)
removeSubnetProxyTargets(
newt.newtId,
[targetToRemove],
newt.version
)
);
}
@@ -811,12 +815,12 @@ export async function rebuildClientAssociationsFromClient(
// Role-based access
const roleIds = await trx
.select({ roleId: userOrgs.roleId })
.from(userOrgs)
.select({ roleId: userOrgRoles.roleId })
.from(userOrgRoles)
.where(
and(
eq(userOrgs.userId, client.userId),
eq(userOrgs.orgId, client.orgId)
eq(userOrgRoles.userId, client.userId),
eq(userOrgRoles.orgId, client.orgId)
)
) // this needs to be locked onto this org or else cross-org access could happen
.then((rows) => rows.map((row) => row.roleId));
@@ -1080,6 +1084,7 @@ async function handleMessagesForClientSites(
continue;
}
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
@@ -1146,7 +1151,7 @@ async function handleMessagesForClientResources(
// Add subnet proxy targets for each site
for (const [siteId, resources] of addedBySite.entries()) {
const [newt] = await trx
.select({ newtId: newts.newtId })
.select({ newtId: newts.newtId, version: newts.version })
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
@@ -1159,7 +1164,7 @@ async function handleMessagesForClientResources(
}
for (const resource of resources) {
const targets = generateSubnetProxyTargets(resource, [
const target = generateSubnetProxyTargetV2(resource, [
{
clientId: client.clientId,
pubKey: client.pubKey,
@@ -1167,8 +1172,14 @@ async function handleMessagesForClientResources(
}
]);
if (targets.length > 0) {
proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets));
if (target) {
proxyJobs.push(
addSubnetProxyTargets(
newt.newtId,
[target],
newt.version
)
);
}
try {
@@ -1217,7 +1228,7 @@ async function handleMessagesForClientResources(
// Remove subnet proxy targets for each site
for (const [siteId, resources] of removedBySite.entries()) {
const [newt] = await trx
.select({ newtId: newts.newtId })
.select({ newtId: newts.newtId, version: newts.version })
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
@@ -1230,7 +1241,7 @@ async function handleMessagesForClientResources(
}
for (const resource of resources) {
const targets = generateSubnetProxyTargets(resource, [
const target = generateSubnetProxyTargetV2(resource, [
{
clientId: client.clientId,
pubKey: client.pubKey,
@@ -1238,9 +1249,13 @@ async function handleMessagesForClientResources(
}
]);
if (targets.length > 0) {
if (target) {
proxyJobs.push(
removeSubnetProxyTargets(newt.newtId, targets)
removeSubnetProxyTargets(
newt.newtId,
[target],
newt.version
)
);
}

View File

@@ -1,16 +0,0 @@
export enum AudienceIds {
SignUps = "",
Subscribed = "",
Churned = "",
Newsletter = ""
}
let resend;
export default resend;
export async function moveEmailToAudience(
email: string,
audienceId: AudienceIds
) {
return;
}

40
server/lib/sanitize.ts Normal file
View File

@@ -0,0 +1,40 @@
/**
* Sanitize a string field before inserting into a database TEXT column.
*
* Two passes are applied:
*
* 1. Lone UTF-16 surrogates JavaScript strings can hold unpaired surrogates
* (e.g. \uD800 without a following \uDC00-\uDFFF codepoint). These are
* valid in JS but cannot be encoded as UTF-8, triggering
* `report_invalid_encoding` in SQLite / Postgres. They are replaced with
* the Unicode replacement character U+FFFD so the data is preserved as a
* visible signal that something was malformed.
*
* 2. Null bytes and C0 control characters SQLite stores TEXT as
* null-terminated C strings, so \x00 in a value causes
* `report_invalid_encoding`. Bots and scanners routinely inject null bytes
* into URLs (e.g. `/path\u0000.jpg`). All C0 control characters in the
* range \x00-\x1F are stripped except for the three that are legitimate in
* text payloads: HT (\x09), LF (\x0A), and CR (\x0D). DEL (\x7F) is also
* stripped.
*/
export function sanitizeString(value: string): string;
export function sanitizeString(
value: string | null | undefined
): string | undefined;
export function sanitizeString(
value: string | null | undefined
): string | undefined {
if (value == null) return undefined;
return (
value
// Replace lone high surrogates (not followed by a low surrogate)
// and lone low surrogates (not preceded by a high surrogate).
.replace(
/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g,
"\uFFFD"
)
// Strip null bytes, C0 control chars (except HT/LF/CR), and DEL.
.replace(/[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]/g, "")
);
}

434
server/lib/sshCA.ts Normal file
View File

@@ -0,0 +1,434 @@
import * as crypto from "crypto";
/**
* SSH CA "Server" - Pure TypeScript Implementation
*
* This module provides basic SSH Certificate Authority functionality using
* only Node.js built-in crypto module. No external dependencies or subprocesses.
*
* Usage:
* 1. generateCA() - Creates a new CA key pair, returns CA info including the
* TrustedUserCAKeys line to add to servers
* 2. signPublicKey() - Signs a user's public key with the CA, returns a certificate
*/
// ============================================================================
// SSH Wire Format Helpers
// ============================================================================
/**
* Encode a string in SSH wire format (4-byte length prefix + data)
*/
function encodeString(data: Buffer | string): Buffer {
const buf = typeof data === "string" ? Buffer.from(data, "utf8") : data;
const len = Buffer.alloc(4);
len.writeUInt32BE(buf.length, 0);
return Buffer.concat([len, buf]);
}
/**
* Encode a uint32 in SSH wire format (big-endian)
*/
function encodeUInt32(value: number): Buffer {
const buf = Buffer.alloc(4);
buf.writeUInt32BE(value, 0);
return buf;
}
/**
* Encode a uint64 in SSH wire format (big-endian)
*/
function encodeUInt64(value: bigint): Buffer {
const buf = Buffer.alloc(8);
buf.writeBigUInt64BE(value, 0);
return buf;
}
/**
* Decode a string from SSH wire format at the given offset
* Returns the string buffer and the new offset
*/
function decodeString(
data: Buffer,
offset: number
): { value: Buffer; newOffset: number } {
const len = data.readUInt32BE(offset);
const value = data.subarray(offset + 4, offset + 4 + len);
return { value, newOffset: offset + 4 + len };
}
// ============================================================================
// SSH Public Key Parsing/Encoding
// ============================================================================
/**
* Parse an OpenSSH public key line (e.g., "ssh-ed25519 AAAA... comment")
*/
function parseOpenSSHPublicKey(pubKeyLine: string): {
keyType: string;
keyData: Buffer;
comment: string;
} {
const parts = pubKeyLine.trim().split(/\s+/);
if (parts.length < 2) {
throw new Error("Invalid public key format");
}
const keyType = parts[0];
const keyData = Buffer.from(parts[1], "base64");
const comment = parts.slice(2).join(" ") || "";
// Verify the key type in the blob matches
const { value: blobKeyType } = decodeString(keyData, 0);
if (blobKeyType.toString("utf8") !== keyType) {
throw new Error(
`Key type mismatch: ${blobKeyType.toString("utf8")} vs ${keyType}`
);
}
return { keyType, keyData, comment };
}
/**
* Encode an Ed25519 public key in OpenSSH format
*/
function encodeEd25519PublicKey(publicKey: Buffer): Buffer {
return Buffer.concat([
encodeString("ssh-ed25519"),
encodeString(publicKey)
]);
}
/**
* Format a public key blob as an OpenSSH public key line
*/
function formatOpenSSHPublicKey(keyBlob: Buffer, comment: string = ""): string {
const { value: keyType } = decodeString(keyBlob, 0);
const base64 = keyBlob.toString("base64");
return `${keyType.toString("utf8")} ${base64}${comment ? " " + comment : ""}`;
}
// ============================================================================
// SSH Certificate Building
// ============================================================================
interface CertificateOptions {
/** Serial number for the certificate */
serial?: bigint;
/** Certificate type: 1 = user, 2 = host */
certType?: number;
/** Key ID (usually username or identifier) */
keyId: string;
/** List of valid principals (usernames the cert is valid for) */
validPrincipals: string[];
/** Valid after timestamp (seconds since epoch) */
validAfter?: bigint;
/** Valid before timestamp (seconds since epoch) */
validBefore?: bigint;
/** Critical options (usually empty for user certs) */
criticalOptions?: Map<string, string>;
/** Extensions to enable */
extensions?: string[];
}
/**
* Build the extensions section of the certificate
*/
function buildExtensions(extensions: string[]): Buffer {
// Extensions are a series of name-value pairs, sorted by name
// For boolean extensions, the value is empty
const sortedExtensions = [...extensions].sort();
const parts: Buffer[] = [];
for (const ext of sortedExtensions) {
parts.push(encodeString(ext));
parts.push(encodeString("")); // Empty value for boolean extensions
}
return encodeString(Buffer.concat(parts));
}
/**
* Build the critical options section
*/
function buildCriticalOptions(options: Map<string, string>): Buffer {
const sortedKeys = [...options.keys()].sort();
const parts: Buffer[] = [];
for (const key of sortedKeys) {
parts.push(encodeString(key));
parts.push(encodeString(encodeString(options.get(key)!)));
}
return encodeString(Buffer.concat(parts));
}
/**
* Build the valid principals section
*/
function buildPrincipals(principals: string[]): Buffer {
const parts: Buffer[] = [];
for (const principal of principals) {
parts.push(encodeString(principal));
}
return encodeString(Buffer.concat(parts));
}
/**
* Extract the raw Ed25519 public key from an OpenSSH public key blob
*/
function extractEd25519PublicKey(keyBlob: Buffer): Buffer {
const { newOffset } = decodeString(keyBlob, 0); // Skip key type
const { value: publicKey } = decodeString(keyBlob, newOffset);
return publicKey;
}
// ============================================================================
// CA Interface
// ============================================================================
export interface CAKeyPair {
/** CA private key in PEM format (keep this secret!) */
privateKeyPem: string;
/** CA public key in PEM format */
publicKeyPem: string;
/** CA public key in OpenSSH format (for TrustedUserCAKeys) */
publicKeyOpenSSH: string;
/** Raw CA public key bytes (Ed25519) */
publicKeyRaw: Buffer;
}
export interface SignedCertificate {
/** The certificate in OpenSSH format (save as id_ed25519-cert.pub or similar) */
certificate: string;
/** The certificate type string */
certType: string;
/** Serial number */
serial: bigint;
/** Key ID */
keyId: string;
/** Valid principals */
validPrincipals: string[];
/** Valid from timestamp */
validAfter: Date;
/** Valid until timestamp */
validBefore: Date;
}
// ============================================================================
// Main Functions
// ============================================================================
/**
* Generate a new SSH Certificate Authority key pair.
*
* Returns the CA keys and the line to add to /etc/ssh/sshd_config:
* TrustedUserCAKeys /etc/ssh/ca.pub
*
* Then save the publicKeyOpenSSH to /etc/ssh/ca.pub on the server.
*
* @param comment - Optional comment for the CA public key
* @returns CA key pair and configuration info
*/
export function generateCA(comment: string = "pangolin-ssh-ca"): CAKeyPair {
// Generate Ed25519 key pair
const { publicKey, privateKey } = crypto.generateKeyPairSync("ed25519", {
publicKeyEncoding: { type: "spki", format: "pem" },
privateKeyEncoding: { type: "pkcs8", format: "pem" }
});
// Get raw public key bytes
const pubKeyObj = crypto.createPublicKey(publicKey);
const rawPubKey = pubKeyObj.export({ type: "spki", format: "der" });
// Ed25519 SPKI format: 12 byte header + 32 byte key
const ed25519PubKey = rawPubKey.subarray(rawPubKey.length - 32);
// Create OpenSSH format public key
const pubKeyBlob = encodeEd25519PublicKey(ed25519PubKey);
const publicKeyOpenSSH = formatOpenSSHPublicKey(pubKeyBlob, comment);
return {
privateKeyPem: privateKey,
publicKeyPem: publicKey,
publicKeyOpenSSH,
publicKeyRaw: ed25519PubKey
};
}
// ============================================================================
// Helper Functions
// ============================================================================
/**
* Get and decrypt the SSH CA keys for an organization.
*
* @param orgId - Organization ID
* @param decryptionKey - Key to decrypt the CA private key (typically server.secret from config)
* @returns CA key pair or null if not found
*/
export async function getOrgCAKeys(
orgId: string,
decryptionKey: string
): Promise<CAKeyPair | null> {
const { db, orgs } = await import("@server/db");
const { eq } = await import("drizzle-orm");
const { decrypt } = await import("@server/lib/crypto");
const [org] = await db
.select({
sshCaPrivateKey: orgs.sshCaPrivateKey,
sshCaPublicKey: orgs.sshCaPublicKey
})
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org || !org.sshCaPrivateKey || !org.sshCaPublicKey) {
return null;
}
const privateKeyPem = decrypt(org.sshCaPrivateKey, decryptionKey);
// Extract raw public key from the OpenSSH format
const { keyData } = parseOpenSSHPublicKey(org.sshCaPublicKey);
const { newOffset } = decodeString(keyData, 0); // Skip key type
const { value: publicKeyRaw } = decodeString(keyData, newOffset);
// Get PEM format of public key
const pubKeyObj = crypto.createPublicKey({
key: privateKeyPem,
format: "pem"
});
const publicKeyPem = pubKeyObj.export({
type: "spki",
format: "pem"
}) as string;
return {
privateKeyPem,
publicKeyPem,
publicKeyOpenSSH: org.sshCaPublicKey,
publicKeyRaw
};
}
/**
* Sign a user's SSH public key with the CA, producing a certificate.
*
* The resulting certificate should be saved alongside the user's private key
* with a -cert.pub suffix. For example:
* - Private key: ~/.ssh/id_ed25519
* - Certificate: ~/.ssh/id_ed25519-cert.pub
*
* @param caPrivateKeyPem - CA private key in PEM format
* @param userPublicKeyLine - User's public key in OpenSSH format
* @param options - Certificate options (principals, validity, etc.)
* @returns Signed certificate
*/
export function signPublicKey(
caPrivateKeyPem: string,
userPublicKeyLine: string,
options: CertificateOptions
): SignedCertificate {
// Parse the user's public key
const { keyType, keyData } = parseOpenSSHPublicKey(userPublicKeyLine);
// Determine certificate type string
let certTypeString: string;
if (keyType === "ssh-ed25519") {
certTypeString = "ssh-ed25519-cert-v01@openssh.com";
} else if (keyType === "ssh-rsa") {
certTypeString = "ssh-rsa-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp256") {
certTypeString = "ecdsa-sha2-nistp256-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp384") {
certTypeString = "ecdsa-sha2-nistp384-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp521") {
certTypeString = "ecdsa-sha2-nistp521-cert-v01@openssh.com";
} else {
throw new Error(`Unsupported key type: ${keyType}`);
}
// Get CA public key from private key
const caPrivKey = crypto.createPrivateKey(caPrivateKeyPem);
const caPubKey = crypto.createPublicKey(caPrivKey);
const caRawPubKey = caPubKey.export({ type: "spki", format: "der" });
const caEd25519PubKey = caRawPubKey.subarray(caRawPubKey.length - 32);
const caPubKeyBlob = encodeEd25519PublicKey(caEd25519PubKey);
// Set defaults
const serial = options.serial ?? BigInt(Date.now());
const certType = options.certType ?? 1; // 1 = user cert
const now = BigInt(Math.floor(Date.now() / 1000));
const validAfter = options.validAfter ?? now - 60n; // 1 minute ago
const validBefore = options.validBefore ?? now + 86400n * 365n; // 1 year from now
// Default extensions for user certificates
const defaultExtensions = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc"
];
const extensions = options.extensions ?? defaultExtensions;
const criticalOptions = options.criticalOptions ?? new Map();
// Generate nonce (random bytes)
const nonce = crypto.randomBytes(32);
// Extract the public key portion from the user's key blob
// For Ed25519: skip the key type string, get the public key (already encoded)
let userKeyPortion: Buffer;
if (keyType === "ssh-ed25519") {
// Skip the key type string, take the rest (which is encodeString(32-byte-key))
const { newOffset } = decodeString(keyData, 0);
userKeyPortion = keyData.subarray(newOffset);
} else {
// For other key types, extract everything after the key type
const { newOffset } = decodeString(keyData, 0);
userKeyPortion = keyData.subarray(newOffset);
}
// Build the certificate body (to be signed)
const certBody = Buffer.concat([
encodeString(certTypeString),
encodeString(nonce),
userKeyPortion,
encodeUInt64(serial),
encodeUInt32(certType),
encodeString(options.keyId),
buildPrincipals(options.validPrincipals),
encodeUInt64(validAfter),
encodeUInt64(validBefore),
buildCriticalOptions(criticalOptions),
buildExtensions(extensions),
encodeString(""), // reserved
encodeString(caPubKeyBlob) // signature key (CA public key)
]);
// Sign the certificate body
const signature = crypto.sign(null, certBody, caPrivKey);
// Build the full signature blob (algorithm + signature)
const signatureBlob = Buffer.concat([
encodeString("ssh-ed25519"),
encodeString(signature)
]);
// Build complete certificate
const certificate = Buffer.concat([certBody, encodeString(signatureBlob)]);
// Format as OpenSSH certificate line
const certLine = `${certTypeString} ${certificate.toString("base64")} ${options.keyId}`;
return {
certificate: certLine,
certType: certTypeString,
serial,
keyId: options.keyId,
validPrincipals: options.validPrincipals,
validAfter: new Date(Number(validAfter) * 1000),
validBefore: new Date(Number(validBefore) * 1000)
};
}

22
server/lib/tokenCache.ts Normal file
View File

@@ -0,0 +1,22 @@
/**
* Returns a cached plaintext token from Redis if one exists and decrypts
* cleanly, otherwise calls `createSession` to mint a fresh token, stores the
* encrypted value in Redis with the given TTL, and returns it.
*
* Failures at the Redis layer are non-fatal the function always falls
* through to session creation so the caller is never blocked by a Redis outage.
*
* @param cacheKey Unique Redis key, e.g. `"newt:token_cache:abc123"`
* @param secret Server secret used for AES encryption/decryption
* @param ttlSeconds Cache TTL in seconds (should match session expiry)
* @param createSession Factory that mints a new session and returns its raw token
*/
export async function getOrCreateCachedToken(
cacheKey: string,
secret: string,
ttlSeconds: number,
createSession: () => Promise<string>
): Promise<string> {
const token = await createSession();
return token;
}

View File

@@ -218,10 +218,11 @@ export class TraefikConfigManager {
return true;
}
// Fetch if it's been more than 24 hours (for renewals)
const dayInMs = 24 * 60 * 60 * 1000;
const timeSinceLastFetch =
Date.now() - this.lastCertificateFetch.getTime();
// Fetch if it's been more than 24 hours (daily routine check)
if (timeSinceLastFetch > dayInMs) {
logger.info("Fetching certificates due to 24-hour renewal check");
return true;
@@ -265,7 +266,7 @@ export class TraefikConfigManager {
return true;
}
// Check if any local certificates are missing or appear to be outdated
// Check if any local certificates are missing (needs immediate fetch)
for (const domain of domainsNeedingCerts) {
const localState = this.lastLocalCertificateState.get(domain);
if (!localState || !localState.exists) {
@@ -274,17 +275,46 @@ export class TraefikConfigManager {
);
return true;
}
}
// Check if certificate is expiring soon (within 30 days)
if (localState.expiresAt) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry = localState.expiresAt - nowInSeconds;
const daysUntilExpiry = secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 30) {
logger.info(
`Fetching certificates due to upcoming expiry for ${domain} (${Math.round(daysUntilExpiry)} days remaining)`
);
return true;
// For expiry checks, throttle to every 6 hours to avoid querying the
// API/DB on every monitor loop. The certificate-service renews certs
// 45 days before expiry, so checking every 6 hours is plenty frequent
// to pick up renewed certs promptly.
const renewalCheckIntervalMs = 6 * 60 * 60 * 1000; // 6 hours
if (timeSinceLastFetch > renewalCheckIntervalMs) {
// Check non-wildcard certs for expiry (within 45 days to match
// the server-side renewal window in certificate-service)
for (const domain of domainsNeedingCerts) {
const localState = this.lastLocalCertificateState.get(domain);
if (localState?.expiresAt) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry =
localState.expiresAt - nowInSeconds;
const daysUntilExpiry = secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 45) {
logger.info(
`Fetching certificates due to upcoming expiry for ${domain} (${Math.round(daysUntilExpiry)} days remaining)`
);
return true;
}
}
}
// Also check wildcard certificates for expiry. These are not
// included in domainsNeedingCerts since their subdomains are
// filtered out, so we must check them separately.
for (const [certDomain, state] of this.lastLocalCertificateState) {
if (state.exists && state.wildcard && state.expiresAt) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry = state.expiresAt - nowInSeconds;
const daysUntilExpiry = secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 45) {
logger.info(
`Fetching certificates due to upcoming expiry for wildcard cert ${certDomain} (${Math.round(daysUntilExpiry)} days remaining)`
);
return true;
}
}
}
}
@@ -361,6 +391,26 @@ export class TraefikConfigManager {
}
}
// Also include wildcard cert base domains that are
// expiring or expired so they get re-fetched even though
// their subdomains were filtered out above.
for (const [certDomain, state] of this
.lastLocalCertificateState) {
if (state.exists && state.wildcard && state.expiresAt) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry =
state.expiresAt - nowInSeconds;
const daysUntilExpiry =
secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 45) {
domainsToFetch.add(certDomain);
logger.info(
`Including expiring wildcard cert domain ${certDomain} in fetch (${Math.round(daysUntilExpiry)} days remaining)`
);
}
}
}
if (domainsToFetch.size > 0) {
// Get valid certificates for domains not covered by wildcards
validCertificates =
@@ -507,11 +557,18 @@ export class TraefikConfigManager {
config.getRawConfig().server
.session_cookie_name,
// deprecated
accessTokenQueryParam:
config.getRawConfig().server
.resource_access_token_param,
accessTokenIdHeader:
config.getRawConfig().server
.resource_access_token_headers.id,
accessTokenHeader:
config.getRawConfig().server
.resource_access_token_headers.token,
resourceSessionRequestParam:
config.getRawConfig().server
.resource_session_request_param

View File

@@ -14,29 +14,38 @@ import logger from "@server/logger";
import config from "@server/lib/config";
import { resources, sites, Target, targets } from "@server/db";
import createPathRewriteMiddleware from "./middleware";
import { sanitize, validatePathRewriteConfig } from "./utils";
import { sanitize, encodePath, validatePathRewriteConfig } from "./utils";
const redirectHttpsMiddlewareName = "redirect-to-https";
const badgerMiddlewareName = "badger";
// Define extended target type with site information
type TargetWithSite = Target & {
resourceId: number;
targetId: number;
ip: string | null;
method: string | null;
port: number | null;
internalPort: number | null;
enabled: boolean;
health: string | null;
site: {
siteId: number;
type: string;
subnet: string | null;
exitNodeId: number | null;
online: boolean;
};
};
export async function getTraefikConfig(
exitNodeId: number,
siteTypes: string[],
filterOutNamespaceDomains = false,
generateLoginPageRouters = false,
allowRawResources = true
filterOutNamespaceDomains = false, // UNUSED BUT USED IN PRIVATE
generateLoginPageRouters = false, // UNUSED BUT USED IN PRIVATE
allowRawResources = true,
allowMaintenancePage = true // UNUSED BUT USED IN PRIVATE
): Promise<any> {
// Define extended target type with site information
type TargetWithSite = Target & {
site: {
siteId: number;
type: string;
subnet: string | null;
exitNodeId: number | null;
online: boolean;
};
};
// Get resources with their targets and sites in a single optimized query
// Start from sites on this exit node, then join to targets and resources
const resourcesWithTargetsAndSites = await db
@@ -59,6 +68,7 @@ export async function getTraefikConfig(
headers: resources.headers,
proxyProtocol: resources.proxyProtocol,
proxyProtocolVersion: resources.proxyProtocolVersion,
// Target fields
targetId: targets.targetId,
targetEnabled: targets.enabled,
@@ -103,10 +113,6 @@ export async function getTraefikConfig(
eq(sites.type, "local")
)
),
or(
ne(targetHealthCheck.hcHealth, "unhealthy"), // Exclude unhealthy targets
isNull(targetHealthCheck.hcHealth) // Include targets with no health check record
),
inArray(sites.type, siteTypes),
allowRawResources
? isNotNull(resources.http) // ignore the http check if allow_raw_resources is true
@@ -121,7 +127,7 @@ export async function getTraefikConfig(
resourcesWithTargetsAndSites.forEach((row) => {
const resourceId = row.resourceId;
const resourceName = sanitize(row.resourceName) || "";
const targetPath = sanitize(row.path) || ""; // Handle null/undefined paths
const targetPath = encodePath(row.path); // Use encodePath to avoid collisions (e.g. "/a/b" vs "/a-b")
const pathMatchType = row.pathMatchType || "";
const rewritePath = row.rewritePath || "";
const rewritePathType = row.rewritePathType || "";
@@ -139,7 +145,7 @@ export async function getTraefikConfig(
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
const key = sanitize(mapKey);
if (!resourcesMap.has(key)) {
if (!resourcesMap.has(mapKey)) {
const validation = validatePathRewriteConfig(
row.path,
row.pathMatchType,
@@ -154,9 +160,10 @@ export async function getTraefikConfig(
return;
}
resourcesMap.set(key, {
resourcesMap.set(mapKey, {
resourceId: row.resourceId,
name: resourceName,
key: key,
fullDomain: row.fullDomain,
ssl: row.ssl,
http: row.http,
@@ -184,8 +191,7 @@ export async function getTraefikConfig(
});
}
// Add target with its associated site data
resourcesMap.get(key).targets.push({
resourcesMap.get(mapKey).targets.push({
resourceId: row.resourceId,
targetId: row.targetId,
ip: row.ip,
@@ -193,6 +199,7 @@ export async function getTraefikConfig(
port: row.port,
internalPort: row.internalPort,
enabled: row.targetEnabled,
health: row.hcHealth,
site: {
siteId: row.siteId,
type: row.siteType,
@@ -221,8 +228,9 @@ export async function getTraefikConfig(
};
// get the key and the resource
for (const [key, resource] of resourcesMap.entries()) {
const targets = resource.targets;
for (const [, resource] of resourcesMap.entries()) {
const targets = resource.targets as TargetWithSite[];
const key = resource.key;
const routerName = `${key}-${resource.name}-router`;
const serviceName = `${key}-${resource.name}-service`;
@@ -470,17 +478,24 @@ export async function getTraefikConfig(
// RECEIVE BANDWIDTH ENDPOINT.
// TODO: HOW TO HANDLE ^^^^^^ BETTER
const anySitesOnline = (
targets as TargetWithSite[]
).some((target: TargetWithSite) => target.site.online);
const anySitesOnline = targets.some(
(target) =>
target.site.online ||
target.site.type === "local" ||
target.site.type === "wireguard"
);
return (
(targets as TargetWithSite[])
.filter((target: TargetWithSite) => {
targets
.filter((target) => {
if (!target.enabled) {
return false;
}
if (target.health == "unhealthy") {
return false;
}
// If any sites are online, exclude offline sites
if (anySitesOnline && !target.site.online) {
return false;
@@ -508,7 +523,7 @@ export async function getTraefikConfig(
}
return true;
})
.map((target: TargetWithSite) => {
.map((target) => {
if (
target.site.type === "local" ||
target.site.type === "wireguard"
@@ -594,16 +609,19 @@ export async function getTraefikConfig(
loadBalancer: {
servers: (() => {
// Check if any sites are online
const anySitesOnline = (
targets as TargetWithSite[]
).some((target: TargetWithSite) => target.site.online);
const anySitesOnline = targets.some(
(target) =>
target.site.online ||
target.site.type === "local" ||
target.site.type === "wireguard"
);
return (targets as TargetWithSite[])
.filter((target: TargetWithSite) => {
return targets
.filter((target) => {
if (!target.enabled) {
return false;
}
// If any sites are online, exclude offline sites
if (anySitesOnline && !target.site.online) {
return false;
@@ -626,7 +644,7 @@ export async function getTraefikConfig(
}
return true;
})
.map((target: TargetWithSite) => {
.map((target) => {
if (
target.site.type === "local" ||
target.site.type === "wireguard"

View File

@@ -0,0 +1,323 @@
import { assertEquals } from "../../../test/assert";
// ── Pure function copies (inlined to avoid pulling in server dependencies) ──
function sanitize(input: string | null | undefined): string | undefined {
if (!input) return undefined;
if (input.length > 50) {
input = input.substring(0, 50);
}
return input
.replace(/[^a-zA-Z0-9-]/g, "-")
.replace(/-+/g, "-")
.replace(/^-|-$/g, "");
}
function encodePath(path: string | null | undefined): string {
if (!path) return "";
return path.replace(/[^a-zA-Z0-9]/g, (ch) => {
return ch.charCodeAt(0).toString(16);
});
}
// ── Helpers ──────────────────────────────────────────────────────────
/**
* Exact replica of the OLD key computation from upstream main.
* Uses sanitize() for paths — this is what had the collision bug.
*/
function oldKeyComputation(
resourceId: number,
path: string | null,
pathMatchType: string | null,
rewritePath: string | null,
rewritePathType: string | null
): string {
const targetPath = sanitize(path) || "";
const pmt = pathMatchType || "";
const rp = rewritePath || "";
const rpt = rewritePathType || "";
const pathKey = [targetPath, pmt, rp, rpt].filter(Boolean).join("-");
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
return sanitize(mapKey) || "";
}
/**
* Replica of the NEW key computation from our fix.
* Uses encodePath() for paths — collision-free.
*/
function newKeyComputation(
resourceId: number,
path: string | null,
pathMatchType: string | null,
rewritePath: string | null,
rewritePathType: string | null
): string {
const targetPath = encodePath(path);
const pmt = pathMatchType || "";
const rp = rewritePath || "";
const rpt = rewritePathType || "";
const pathKey = [targetPath, pmt, rp, rpt].filter(Boolean).join("-");
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
return sanitize(mapKey) || "";
}
// ── Tests ────────────────────────────────────────────────────────────
function runTests() {
console.log("Running path encoding tests...\n");
let passed = 0;
// ── encodePath unit tests ────────────────────────────────────────
// Test 1: null/undefined/empty
{
assertEquals(encodePath(null), "", "null should return empty");
assertEquals(
encodePath(undefined),
"",
"undefined should return empty"
);
assertEquals(encodePath(""), "", "empty string should return empty");
console.log(" PASS: encodePath handles null/undefined/empty");
passed++;
}
// Test 2: root path
{
assertEquals(encodePath("/"), "2f", "/ should encode to 2f");
console.log(" PASS: encodePath encodes root path");
passed++;
}
// Test 3: alphanumeric passthrough
{
assertEquals(encodePath("/api"), "2fapi", "/api encodes slash only");
assertEquals(encodePath("/v1"), "2fv1", "/v1 encodes slash only");
assertEquals(encodePath("abc"), "abc", "plain alpha passes through");
console.log(" PASS: encodePath preserves alphanumeric chars");
passed++;
}
// Test 4: all special chars produce unique hex
{
const paths = ["/a/b", "/a-b", "/a.b", "/a_b", "/a b"];
const results = paths.map((p) => encodePath(p));
const unique = new Set(results);
assertEquals(
unique.size,
paths.length,
"all special-char paths must produce unique encodings"
);
console.log(
" PASS: encodePath produces unique output for different special chars"
);
passed++;
}
// Test 5: output is always alphanumeric (safe for Traefik names)
{
const paths = [
"/",
"/api",
"/a/b",
"/a-b",
"/a.b",
"/complex/path/here"
];
for (const p of paths) {
const e = encodePath(p);
assertEquals(
/^[a-zA-Z0-9]+$/.test(e),
true,
`encodePath("${p}") = "${e}" must be alphanumeric`
);
}
console.log(" PASS: encodePath output is always alphanumeric");
passed++;
}
// Test 6: deterministic
{
assertEquals(
encodePath("/api"),
encodePath("/api"),
"same input same output"
);
assertEquals(
encodePath("/a/b/c"),
encodePath("/a/b/c"),
"same input same output"
);
console.log(" PASS: encodePath is deterministic");
passed++;
}
// Test 7: many distinct paths never collide
{
const paths = [
"/",
"/api",
"/api/v1",
"/api/v2",
"/a/b",
"/a-b",
"/a.b",
"/a_b",
"/health",
"/health/check",
"/admin",
"/admin/users",
"/api/v1/users",
"/api/v1/posts",
"/app",
"/app/dashboard"
];
const encoded = new Set(paths.map((p) => encodePath(p)));
assertEquals(
encoded.size,
paths.length,
`expected ${paths.length} unique encodings, got ${encoded.size}`
);
console.log(" PASS: 16 realistic paths all produce unique encodings");
passed++;
}
// ── Collision fix: the actual bug we're fixing ───────────────────
// Test 8: /a/b and /a-b now have different keys (THE BUG FIX)
{
const keyAB = newKeyComputation(1, "/a/b", "prefix", null, null);
const keyDash = newKeyComputation(1, "/a-b", "prefix", null, null);
assertEquals(
keyAB !== keyDash,
true,
"/a/b and /a-b MUST have different keys"
);
console.log(" PASS: collision fix — /a/b vs /a-b have different keys");
passed++;
}
// Test 9: demonstrate the old bug — old code maps /a/b and /a-b to same key
{
const oldKeyAB = oldKeyComputation(1, "/a/b", "prefix", null, null);
const oldKeyDash = oldKeyComputation(1, "/a-b", "prefix", null, null);
assertEquals(
oldKeyAB,
oldKeyDash,
"old code MUST have this collision (confirms the bug exists)"
);
console.log(" PASS: confirmed old code bug — /a/b and /a-b collided");
passed++;
}
// Test 10: /api/v1 and /api-v1 — old code collision, new code fixes it
{
const oldKey1 = oldKeyComputation(1, "/api/v1", "prefix", null, null);
const oldKey2 = oldKeyComputation(1, "/api-v1", "prefix", null, null);
assertEquals(
oldKey1,
oldKey2,
"old code collision for /api/v1 vs /api-v1"
);
const newKey1 = newKeyComputation(1, "/api/v1", "prefix", null, null);
const newKey2 = newKeyComputation(1, "/api-v1", "prefix", null, null);
assertEquals(
newKey1 !== newKey2,
true,
"new code must separate /api/v1 and /api-v1"
);
console.log(" PASS: collision fix — /api/v1 vs /api-v1");
passed++;
}
// Test 11: /app.v2 and /app/v2 and /app-v2 — three-way collision fixed
{
const a = newKeyComputation(1, "/app.v2", "prefix", null, null);
const b = newKeyComputation(1, "/app/v2", "prefix", null, null);
const c = newKeyComputation(1, "/app-v2", "prefix", null, null);
const keys = new Set([a, b, c]);
assertEquals(
keys.size,
3,
"three paths must produce three unique keys"
);
console.log(
" PASS: collision fix — three-way /app.v2, /app/v2, /app-v2"
);
passed++;
}
// ── Edge cases ───────────────────────────────────────────────────
// Test 12: same path in different resources — always separate
{
const key1 = newKeyComputation(1, "/api", "prefix", null, null);
const key2 = newKeyComputation(2, "/api", "prefix", null, null);
assertEquals(
key1 !== key2,
true,
"different resources with same path must have different keys"
);
console.log(" PASS: edge case — same path, different resources");
passed++;
}
// Test 13: same resource, different pathMatchType — separate keys
{
const exact = newKeyComputation(1, "/api", "exact", null, null);
const prefix = newKeyComputation(1, "/api", "prefix", null, null);
assertEquals(
exact !== prefix,
true,
"exact vs prefix must have different keys"
);
console.log(" PASS: edge case — same path, different match types");
passed++;
}
// Test 14: same resource and path, different rewrite config — separate keys
{
const noRewrite = newKeyComputation(1, "/api", "prefix", null, null);
const withRewrite = newKeyComputation(
1,
"/api",
"prefix",
"/backend",
"prefix"
);
assertEquals(
noRewrite !== withRewrite,
true,
"with vs without rewrite must have different keys"
);
console.log(" PASS: edge case — same path, different rewrite config");
passed++;
}
// Test 15: paths with special URL characters
{
const paths = ["/api?foo", "/api#bar", "/api%20baz", "/api+qux"];
const keys = new Set(
paths.map((p) => newKeyComputation(1, p, "prefix", null, null))
);
assertEquals(
keys.size,
paths.length,
"special URL chars must produce unique keys"
);
console.log(" PASS: edge case — special URL characters in paths");
passed++;
}
console.log(`\nAll ${passed} tests passed!`);
}
try {
runTests();
} catch (error) {
console.error("Test failed:", error);
process.exit(1);
}

View File

@@ -13,6 +13,26 @@ export function sanitize(input: string | null | undefined): string | undefined {
.replace(/^-|-$/g, "");
}
/**
* Encode a URL path into a collision-free alphanumeric string suitable for use
* in Traefik map keys.
*
* Unlike sanitize(), this preserves uniqueness by encoding each non-alphanumeric
* character as its hex code. Different paths always produce different outputs.
*
* encodePath("/api") => "2fapi"
* encodePath("/a/b") => "2fa2fb"
* encodePath("/a-b") => "2fa2db" (different from /a/b)
* encodePath("/") => "2f"
* encodePath(null) => ""
*/
export function encodePath(path: string | null | undefined): string {
if (!path) return "";
return path.replace(/[^a-zA-Z0-9]/g, (ch) => {
return ch.charCodeAt(0).toString(16);
});
}
export function validatePathRewriteConfig(
path: string | null,
pathMatchType: string | null,

163
server/lib/userOrg.ts Normal file
View File

@@ -0,0 +1,163 @@
import {
db,
Org,
orgs,
resources,
siteResources,
sites,
Transaction,
userOrgRoles,
userOrgs,
userResources,
userSiteResources,
userSites
} from "@server/db";
import { eq, and, inArray, ne, exists } from "drizzle-orm";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
export async function assignUserToOrg(
org: Org,
values: typeof userOrgs.$inferInsert,
roleIds: number[],
trx: Transaction | typeof db = db
) {
const uniqueRoleIds = [...new Set(roleIds)];
if (uniqueRoleIds.length === 0) {
throw new Error("assignUserToOrg requires at least one roleId");
}
const [userOrg] = await trx.insert(userOrgs).values(values).returning();
await trx.insert(userOrgRoles).values(
uniqueRoleIds.map((roleId) => ({
userId: userOrg.userId,
orgId: userOrg.orgId,
roleId
}))
);
// calculate if the user is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, org.orgId)
)
);
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userOrg.userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(org.orgId, FeatureId.USERS, 1, trx);
}
}
}
export async function removeUserFromOrg(
org: Org,
userId: string,
trx: Transaction | typeof db = db
) {
await trx
.delete(userOrgRoles)
.where(
and(
eq(userOrgRoles.userId, userId),
eq(userOrgRoles.orgId, org.orgId)
)
);
await trx
.delete(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, org.orgId)));
await trx.delete(userResources).where(
and(
eq(userResources.userId, userId),
exists(
trx
.select()
.from(resources)
.where(
and(
eq(resources.resourceId, userResources.resourceId),
eq(resources.orgId, org.orgId)
)
)
)
)
);
await trx.delete(userSiteResources).where(
and(
eq(userSiteResources.userId, userId),
exists(
trx
.select()
.from(siteResources)
.where(
and(
eq(
siteResources.siteResourceId,
userSiteResources.siteResourceId
),
eq(siteResources.orgId, org.orgId)
)
)
)
)
);
await trx.delete(userSites).where(
and(
eq(userSites.userId, userId),
exists(
db
.select()
.from(sites)
.where(
and(
eq(sites.siteId, userSites.siteId),
eq(sites.orgId, org.orgId)
)
)
)
)
);
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
const billingOrgIds = billingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(org.orgId, FeatureId.USERS, -1, trx);
}
}
}

View File

@@ -0,0 +1,36 @@
import { db, roles, userOrgRoles } from "@server/db";
import { and, eq } from "drizzle-orm";
/**
* Get all role IDs a user has in an organization.
* Returns empty array if the user has no roles in the org (callers must treat as no access).
*/
export async function getUserOrgRoleIds(
userId: string,
orgId: string
): Promise<number[]> {
const rows = await db
.select({ roleId: userOrgRoles.roleId })
.from(userOrgRoles)
.where(
and(
eq(userOrgRoles.userId, userId),
eq(userOrgRoles.orgId, orgId)
)
);
return rows.map((r) => r.roleId);
}
export async function getUserOrgRoles(
userId: string,
orgId: string
): Promise<{ roleId: number; roleName: string }[]> {
const rows = await db
.select({ roleId: userOrgRoles.roleId, roleName: roles.name })
.from(userOrgRoles)
.innerJoin(roles, eq(userOrgRoles.roleId, roles.roleId))
.where(
and(eq(userOrgRoles.userId, userId), eq(userOrgRoles.orgId, orgId))
);
return rows;
}

View File

@@ -12,6 +12,10 @@ export type LicenseStatus = {
isLicenseValid: boolean; // Is the license key valid?
hostId: string; // Host ID
tier?: LicenseKeyTier;
maxSites?: number;
usedSites?: number;
maxUsers?: number;
usedUsers?: number;
};
export type LicenseKeyCache = {
@@ -22,12 +26,14 @@ export type LicenseKeyCache = {
type?: LicenseKeyType;
tier?: LicenseKeyTier;
terminateAt?: Date;
quantity?: number;
quantity_2?: number;
};
export class License {
private serverSecret!: string;
constructor(private hostMeta: HostMeta) {}
constructor(private hostMeta: HostMeta) { }
public async check(): Promise<LicenseStatus> {
return {

View File

@@ -21,8 +21,7 @@ export async function getUserOrgs(
try {
const userOrganizations = await db
.select({
orgId: userOrgs.orgId,
roleId: userOrgs.roleId
orgId: userOrgs.orgId
})
.from(userOrgs)
.where(eq(userOrgs.userId, userId));

View File

@@ -17,6 +17,7 @@ export * from "./verifyAccessTokenAccess";
export * from "./requestTimeout";
export * from "./verifyClientAccess";
export * from "./verifyUserHasAction";
export * from "./verifyUserCanSetUserOrgRoles";
export * from "./verifyUserIsServerAdmin";
export * from "./verifyIsLoggedInUser";
export * from "./verifyIsLoggedInUser";
@@ -24,8 +25,10 @@ export * from "./verifyClientAccess";
export * from "./integration";
export * from "./verifyUserHasAction";
export * from "./verifyApiKeyAccess";
export * from "./verifySiteProvisioningKeyAccess";
export * from "./verifyDomainAccess";
export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess";
export * from "./logActionAudit";
export * from "./verifyOlmAccess";
export * from "./verifyLimits";

View File

@@ -1,6 +1,7 @@
export * from "./verifyApiKey";
export * from "./verifyApiKeyOrgAccess";
export * from "./verifyApiKeyHasAction";
export * from "./verifyApiKeyCanSetUserOrgRoles";
export * from "./verifyApiKeySiteAccess";
export * from "./verifyApiKeyResourceAccess";
export * from "./verifyApiKeyTargetAccess";
@@ -13,3 +14,5 @@ export * from "./verifyApiKeyIsRoot";
export * from "./verifyApiKeyApiKeyAccess";
export * from "./verifyApiKeyClientAccess";
export * from "./verifyApiKeySiteResourceAccess";
export * from "./verifyApiKeyIdpAccess";
export * from "./verifyApiKeyDomainAccess";

View File

@@ -0,0 +1,74 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { ActionsEnum } from "@server/auth/actions";
import { db } from "@server/db";
import { apiKeyActions } from "@server/db";
import { and, eq } from "drizzle-orm";
async function apiKeyHasAction(apiKeyId: string, actionId: ActionsEnum) {
const [row] = await db
.select()
.from(apiKeyActions)
.where(
and(
eq(apiKeyActions.apiKeyId, apiKeyId),
eq(apiKeyActions.actionId, actionId)
)
);
return !!row;
}
/**
* Allows setUserOrgRoles on the key, or both addUserRole and removeUserRole.
*/
export function verifyApiKeyCanSetUserOrgRoles() {
return async function (
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
if (!req.apiKey) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"API Key not authenticated"
)
);
}
const keyId = req.apiKey.apiKeyId;
if (await apiKeyHasAction(keyId, ActionsEnum.setUserOrgRoles)) {
return next();
}
const hasAdd = await apiKeyHasAction(keyId, ActionsEnum.addUserRole);
const hasRemove = await apiKeyHasAction(
keyId,
ActionsEnum.removeUserRole
);
if (hasAdd && hasRemove) {
return next();
}
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have permission perform this action"
)
);
} catch (error) {
logger.error("Error verifying API key set user org roles:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying key action access"
)
);
}
};
}

View File

@@ -0,0 +1,90 @@
import { Request, Response, NextFunction } from "express";
import { db, domains, orgDomains, apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyApiKeyDomainAccess(
req: Request,
res: Response,
next: NextFunction
) {
try {
const apiKey = req.apiKey;
const domainId =
req.params.domainId || req.body.domainId || req.query.domainId;
const orgId = req.params.orgId;
if (!apiKey) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Key not authenticated")
);
}
if (!domainId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid domain ID")
);
}
if (apiKey.isRoot) {
// Root keys can access any domain in any org
return next();
}
// Verify domain exists and belongs to the organization
const [domain] = await db
.select()
.from(domains)
.innerJoin(orgDomains, eq(orgDomains.domainId, domains.domainId))
.where(
and(
eq(orgDomains.domainId, domainId),
eq(orgDomains.orgId, orgId)
)
)
.limit(1);
if (!domain) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Domain with ID ${domainId} not found in organization ${orgId}`
)
);
}
// Verify the API key has access to this organization
if (!req.apiKeyOrg) {
const apiKeyOrgRes = await db
.select()
.from(apiKeyOrg)
.where(
and(
eq(apiKeyOrg.apiKeyId, apiKey.apiKeyId),
eq(apiKeyOrg.orgId, orgId)
)
)
.limit(1);
req.apiKeyOrg = apiKeyOrgRes[0];
}
if (!req.apiKeyOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to this organization"
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying domain access"
)
);
}
}

View File

@@ -0,0 +1,88 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { idp, idpOrg, apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyApiKeyIdpAccess(
req: Request,
res: Response,
next: NextFunction
) {
try {
const apiKey = req.apiKey;
const idpId = req.params.idpId || req.body.idpId || req.query.idpId;
const orgId = req.params.orgId;
if (!apiKey) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Key not authenticated")
);
}
if (!orgId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID")
);
}
if (!idpId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid IDP ID")
);
}
if (apiKey.isRoot) {
// Root keys can access any IDP in any org
return next();
}
const [idpRes] = await db
.select()
.from(idp)
.innerJoin(idpOrg, eq(idp.idpId, idpOrg.idpId))
.where(and(eq(idp.idpId, idpId), eq(idpOrg.orgId, orgId)))
.limit(1);
if (!idpRes || !idpRes.idp || !idpRes.idpOrg) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`IdP with ID ${idpId} not found for organization ${orgId}`
)
);
}
if (!req.apiKeyOrg) {
const apiKeyOrgRes = await db
.select()
.from(apiKeyOrg)
.where(
and(
eq(apiKeyOrg.apiKeyId, apiKey.apiKeyId),
eq(apiKeyOrg.orgId, idpRes.idpOrg.orgId)
)
);
req.apiKeyOrg = apiKeyOrgRes[0];
}
if (!req.apiKeyOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to this organization"
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying IDP access"
)
);
}
}

View File

@@ -4,7 +4,6 @@ import { apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
export async function verifyApiKeyOrgAccess(
req: Request,

View File

@@ -23,9 +23,14 @@ export async function verifyApiKeyRoleAccess(
);
}
const { roleIds } = req.body;
const allRoleIds =
roleIds || (isNaN(singleRoleId) ? [] : [singleRoleId]);
let allRoleIds: number[] = [];
if (!isNaN(singleRoleId)) {
// If roleId is provided in URL params, query params, or body (single), use it exclusively
allRoleIds = [singleRoleId];
} else if (req.body?.roleIds) {
// Only use body.roleIds if no single roleId was provided
allRoleIds = req.body.roleIds;
}
if (allRoleIds.length === 0) {
return next();

View File

@@ -6,6 +6,7 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { canUserAccessResource } from "@server/auth/canUserAccessResource";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyAccessTokenAccess(
req: Request,
@@ -93,7 +94,10 @@ export async function verifyAccessTokenAccess(
)
);
} else {
req.userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
resource[0].orgId!
);
req.userOrgId = resource[0].orgId!;
}
@@ -118,7 +122,7 @@ export async function verifyAccessTokenAccess(
const resourceAllowed = await canUserAccessResource({
userId,
resourceId,
roleId: req.userOrgRoleId!
roleIds: req.userOrgRoleIds ?? []
});
if (!resourceAllowed) {

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { roles, userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyAdmin(
req: Request,
@@ -62,13 +63,29 @@ export async function verifyAdmin(
}
}
const userRole = await db
req.userOrgRoleIds = await getUserOrgRoleIds(req.userOrg.userId, orgId!);
if (req.userOrgRoleIds.length === 0) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have Admin access"
)
);
}
const userAdminRoles = await db
.select()
.from(roles)
.where(eq(roles.roleId, req.userOrg.roleId))
.where(
and(
inArray(roles.roleId, req.userOrgRoleIds),
eq(roles.isAdmin, true)
)
)
.limit(1);
if (userRole.length === 0 || !userRole[0].isAdmin) {
if (userAdminRoles.length === 0) {
return next(
createHttpError(
HttpCode.FORBIDDEN,

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { userOrgs, apiKeys, apiKeyOrg } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyApiKeyAccess(
req: Request,
@@ -103,8 +104,10 @@ export async function verifyApiKeyAccess(
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
orgId
);
return next();
} catch (error) {

View File

@@ -1,11 +1,12 @@
import { Request, Response, NextFunction } from "express";
import { Client, db } from "@server/db";
import { userOrgs, clients, roleClients, userClients } from "@server/db";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import logger from "@server/logger";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyClientAccess(
req: Request,
@@ -113,21 +114,30 @@ export async function verifyClientAccess(
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
client.orgId
);
req.userOrgId = client.orgId;
// Check role-based site access first
const [roleClientAccess] = await db
.select()
.from(roleClients)
.where(
and(
eq(roleClients.clientId, client.clientId),
eq(roleClients.roleId, userOrgRoleId)
)
)
.limit(1);
// Check role-based client access (any of user's roles)
const roleClientAccessList =
(req.userOrgRoleIds?.length ?? 0) > 0
? await db
.select()
.from(roleClients)
.where(
and(
eq(roleClients.clientId, client.clientId),
inArray(
roleClients.roleId,
req.userOrgRoleIds!
)
)
)
.limit(1)
: [];
const [roleClientAccess] = roleClientAccessList;
if (roleClientAccess) {
// User has access to the site through their role

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db, domains, orgDomains } from "@server/db";
import { userOrgs, apiKeyOrg } from "@server/db";
import { userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyDomainAccess(
req: Request,
@@ -63,7 +64,7 @@ export async function verifyDomainAccess(
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, apiKeyOrg.orgId)
eq(userOrgs.orgId, orgId)
)
)
.limit(1);
@@ -97,8 +98,7 @@ export async function verifyDomainAccess(
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgRoleIds = await getUserOrgRoleIds(req.userOrg.userId, orgId);
return next();
} catch (error) {

View File

@@ -0,0 +1,43 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
export async function verifyLimits(
req: Request,
res: Response,
next: NextFunction
) {
if (build != "saas") {
return next();
}
const orgId = req.userOrgId || req.apiKeyOrg?.orgId || req.params.orgId;
if (!orgId) {
return next(); // its fine if we silently fail here because this is not critical to operation or security and its better user experience if we dont fail
}
try {
const reject = await usageService.checkLimitSet(orgId);
if (reject) {
return next(
createHttpError(
HttpCode.PAYMENT_REQUIRED,
"Organization has exceeded its usage limits. Please upgrade your plan or contact support."
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking limits"
)
);
}
}

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db, orgs } from "@server/db";
import { db } from "@server/db";
import { userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyOrgAccess(
req: Request,
@@ -64,8 +65,8 @@ export async function verifyOrgAccess(
}
}
// User has access, attach the user's role to the request for potential future use
req.userOrgRoleId = req.userOrg.roleId;
// User has access, attach the user's role(s) to the request for potential future use
req.userOrgRoleIds = await getUserOrgRoleIds(req.userOrg.userId, orgId);
req.userOrgId = orgId;
return next();

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db, Resource } from "@server/db";
import { resources, userOrgs, userResources, roleResources } from "@server/db";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyResourceAccess(
req: Request,
@@ -107,20 +108,28 @@ export async function verifyResourceAccess(
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
resource.orgId
);
req.userOrgId = resource.orgId;
const roleResourceAccess = await db
.select()
.from(roleResources)
.where(
and(
eq(roleResources.resourceId, resource.resourceId),
eq(roleResources.roleId, userOrgRoleId)
)
)
.limit(1);
const roleResourceAccess =
(req.userOrgRoleIds?.length ?? 0) > 0
? await db
.select()
.from(roleResources)
.where(
and(
eq(roleResources.resourceId, resource.resourceId),
inArray(
roleResources.roleId,
req.userOrgRoleIds!
)
)
)
.limit(1)
: [];
if (roleResourceAccess.length > 0) {
return next();

View File

@@ -6,6 +6,7 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyRoleAccess(
req: Request,
@@ -23,8 +24,14 @@ export async function verifyRoleAccess(
);
}
const roleIds = req.body?.roleIds;
const allRoleIds = roleIds || (isNaN(singleRoleId) ? [] : [singleRoleId]);
let allRoleIds: number[] = [];
if (!isNaN(singleRoleId)) {
// If roleId is provided in URL params, query params, or body (single), use it exclusively
allRoleIds = [singleRoleId];
} else if (req.body?.roleIds) {
// Only use body.roleIds if no single roleId was provided
allRoleIds = req.body.roleIds;
}
if (allRoleIds.length === 0) {
return next();
@@ -93,7 +100,6 @@ export async function verifyRoleAccess(
}
if (!req.userOrg) {
// get the userORg
const userOrg = await db
.select()
.from(userOrgs)
@@ -103,7 +109,7 @@ export async function verifyRoleAccess(
.limit(1);
req.userOrg = userOrg[0];
req.userOrgRoleId = userOrg[0].roleId;
req.userOrgRoleIds = await getUserOrgRoleIds(userId, orgId!);
}
if (!req.userOrg) {

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { sites, Site, userOrgs, userSites, roleSites, roles } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import { and, eq, inArray, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifySiteAccess(
req: Request,
@@ -112,21 +113,29 @@ export async function verifySiteAccess(
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
site.orgId
);
req.userOrgId = site.orgId;
// Check role-based site access first
const roleSiteAccess = await db
.select()
.from(roleSites)
.where(
and(
eq(roleSites.siteId, site.siteId),
eq(roleSites.roleId, userOrgRoleId)
)
)
.limit(1);
// Check role-based site access first (any of user's roles)
const roleSiteAccess =
(req.userOrgRoleIds?.length ?? 0) > 0
? await db
.select()
.from(roleSites)
.where(
and(
eq(roleSites.siteId, site.siteId),
inArray(
roleSites.roleId,
req.userOrgRoleIds!
)
)
)
.limit(1)
: [];
if (roleSiteAccess.length > 0) {
// User's role has access to the site

View File

@@ -0,0 +1,131 @@
import { Request, Response, NextFunction } from "express";
import { db, userOrgs, siteProvisioningKeys, siteProvisioningKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifySiteProvisioningKeyAccess(
req: Request,
res: Response,
next: NextFunction
) {
try {
const userId = req.user!.userId;
const siteProvisioningKeyId = req.params.siteProvisioningKeyId;
const orgId = req.params.orgId;
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
if (!orgId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID")
);
}
if (!siteProvisioningKeyId) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid key ID")
);
}
const [row] = await db
.select()
.from(siteProvisioningKeys)
.innerJoin(
siteProvisioningKeyOrg,
and(
eq(
siteProvisioningKeys.siteProvisioningKeyId,
siteProvisioningKeyOrg.siteProvisioningKeyId
),
eq(siteProvisioningKeyOrg.orgId, orgId)
)
)
.where(
eq(
siteProvisioningKeys.siteProvisioningKeyId,
siteProvisioningKeyId
)
)
.limit(1);
if (!row?.siteProvisioningKeys) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site provisioning key with ID ${siteProvisioningKeyId} not found`
)
);
}
if (!row.siteProvisioningKeyOrg.orgId) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Site provisioning key with ID ${siteProvisioningKeyId} does not have an organization ID`
)
);
}
if (!req.userOrg) {
const userOrgRole = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(
userOrgs.orgId,
row.siteProvisioningKeyOrg.orgId
)
)
)
.limit(1);
req.userOrg = userOrgRole[0];
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying site provisioning key access"
)
);
}
}

View File

@@ -1,11 +1,12 @@
import { Request, Response, NextFunction } from "express";
import { db, roleSiteResources, userOrgs, userSiteResources } from "@server/db";
import { siteResources } from "@server/db";
import { eq, and } from "drizzle-orm";
import { eq, and, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifySiteResourceAccess(
req: Request,
@@ -109,23 +110,34 @@ export async function verifySiteResourceAccess(
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
siteResource.orgId
);
req.userOrgId = siteResource.orgId;
// Attach the siteResource to the request for use in the next middleware/route
req.siteResource = siteResource;
const roleResourceAccess = await db
.select()
.from(roleSiteResources)
.where(
and(
eq(roleSiteResources.siteResourceId, siteResourceIdNum),
eq(roleSiteResources.roleId, userOrgRoleId)
)
)
.limit(1);
const roleResourceAccess =
(req.userOrgRoleIds?.length ?? 0) > 0
? await db
.select()
.from(roleSiteResources)
.where(
and(
eq(
roleSiteResources.siteResourceId,
siteResourceIdNum
),
inArray(
roleSiteResources.roleId,
req.userOrgRoleIds!
)
)
)
.limit(1)
: [];
if (roleResourceAccess.length > 0) {
return next();

View File

@@ -6,6 +6,7 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { canUserAccessResource } from "../auth/canUserAccessResource";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { getUserOrgRoleIds } from "@server/lib/userOrgRoles";
export async function verifyTargetAccess(
req: Request,
@@ -99,7 +100,10 @@ export async function verifyTargetAccess(
)
);
} else {
req.userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleIds = await getUserOrgRoleIds(
req.userOrg.userId,
resource[0].orgId!
);
req.userOrgId = resource[0].orgId!;
}
@@ -126,7 +130,7 @@ export async function verifyTargetAccess(
const resourceAllowed = await canUserAccessResource({
userId,
resourceId,
roleId: req.userOrgRoleId!
roleIds: req.userOrgRoleIds ?? []
});
if (!resourceAllowed) {

View File

@@ -0,0 +1,54 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { ActionsEnum, checkUserActionPermission } from "@server/auth/actions";
/**
* Allows the new setUserOrgRoles action, or legacy permission pair addUserRole + removeUserRole.
*/
export function verifyUserCanSetUserOrgRoles() {
return async function (
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const canSet = await checkUserActionPermission(
ActionsEnum.setUserOrgRoles,
req
);
if (canSet) {
return next();
}
const canAdd = await checkUserActionPermission(
ActionsEnum.addUserRole,
req
);
const canRemove = await checkUserActionPermission(
ActionsEnum.removeUserRole,
req
);
if (canAdd && canRemove) {
return next();
}
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have permission perform this action"
)
);
} catch (error) {
logger.error("Error verifying set user org roles access:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying role access"
)
);
}
};
}

View File

@@ -12,7 +12,7 @@ export async function verifyUserInRole(
const roleId = parseInt(
req.params.roleId || req.body.roleId || req.query.roleId
);
const userRoleId = req.userOrgRoleId;
const userOrgRoleIds = req.userOrgRoleIds ?? [];
if (isNaN(roleId)) {
return next(
@@ -20,7 +20,7 @@ export async function verifyUserInRole(
);
}
if (!userRoleId) {
if (userOrgRoleIds.length === 0) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
@@ -29,7 +29,7 @@ export async function verifyUserInRole(
);
}
if (userRoleId !== roleId) {
if (!userOrgRoleIds.includes(roleId)) {
return next(
createHttpError(
HttpCode.FORBIDDEN,

View File

@@ -5,16 +5,20 @@ export const registry = new OpenAPIRegistry();
export enum OpenAPITags {
Site = "Site",
Org = "Organization",
Resource = "Resource",
PublicResource = "Public Resource",
PrivateResource = "Private Resource",
Role = "Role",
User = "User",
Invitation = "Invitation",
Target = "Target",
Invitation = "User Invitation",
Target = "Resource Target",
Rule = "Rule",
AccessToken = "Access Token",
Idp = "Identity Provider",
GlobalIdp = "Identity Provider (Global)",
OrgIdp = "Identity Provider (Organization Only)",
Client = "Client",
ApiKey = "API Key",
Domain = "Domain",
Blueprint = "Blueprint"
Blueprint = "Blueprint",
Ssh = "SSH",
Logs = "Logs"
}

View File

@@ -13,8 +13,16 @@
import { rateLimitService } from "#private/lib/rateLimit";
import { cleanup as wsCleanup } from "#private/routers/ws";
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushConnectionLogToDb } from "#dynamic/routers/newt";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
import { stopPingAccumulator } from "@server/routers/newt/pingAccumulator";
async function cleanup() {
await stopPingAccumulator();
await flushBandwidthToDb();
await flushConnectionLogToDb();
await flushSiteBandwidthToDb();
await rateLimitService.cleanup();
await wsCleanup();

Some files were not shown because too many files have changed in this diff Show More