Compare commits

..

6 Commits

Author SHA1 Message Date
Owen
b4c01349d1 Merge branch 'dev' 2026-02-04 21:44:07 -08:00
Owen
e4d4c62833 Dont create newt sites with exit node or subnet 2026-02-02 18:19:13 -08:00
Owen
20ae903d7f Subscribed limits for domains is higher 2026-02-02 16:46:48 -08:00
MoweME
b0566d3c6f fix(i18n): correct German site terminology
Updates the German translation to use "Standort" (site) instead of "Seite" (page) for consistency with the site context.
2026-01-29 10:01:30 -08:00
MoweME
5dda8c384f fix(i18n): correct German translation strings
Corrects mistranslation of device timestamp labels and fixes product name reference in site tunnel settings.
2026-01-29 10:01:30 -08:00
Owen
cb569ff14d Properly insert PANGOLIN_SETUP_TOKEN into db
Fixes #2361
2026-01-28 15:03:31 -08:00
150 changed files with 2830 additions and 3832 deletions

View File

@@ -7,8 +7,8 @@ services:
POSTGRES_DB: postgres # Default database name POSTGRES_DB: postgres # Default database name
POSTGRES_USER: postgres # Default user POSTGRES_USER: postgres # Default user
POSTGRES_PASSWORD: password # Default password (change for production!) POSTGRES_PASSWORD: password # Default password (change for production!)
# volumes: volumes:
# - ./config/postgres:/var/lib/postgresql/data - ./config/postgres:/var/lib/postgresql/data
ports: ports:
- "5432:5432" # Map host port 5432 to container port 5432 - "5432:5432" # Map host port 5432 to container port 5432
restart: no restart: no

View File

@@ -1,14 +0,0 @@
import { defineConfig } from "drizzle-kit";
import path from "path";
const schema = [path.join("server", "db", "pg", "schema")];
export default defineConfig({
dialect: "postgresql",
schema: schema,
out: path.join("server", "migrations"),
verbose: true,
dbCredentials: {
url: process.env.DATABASE_URL as string
}
});

View File

@@ -97,7 +97,7 @@
"siteGeneralDescription": "Allgemeine Einstellungen für diesen Standort konfigurieren", "siteGeneralDescription": "Allgemeine Einstellungen für diesen Standort konfigurieren",
"siteSettingDescription": "Standorteinstellungen konfigurieren", "siteSettingDescription": "Standorteinstellungen konfigurieren",
"siteSetting": "{siteName} Einstellungen", "siteSetting": "{siteName} Einstellungen",
"siteNewtTunnel": "Neuer Standort (empfohlen)", "siteNewtTunnel": "Newt Standort (empfohlen)",
"siteNewtTunnelDescription": "Einfachster Weg, einen Einstiegspunkt in jedes Netzwerk zu erstellen. Keine zusätzliche Einrichtung.", "siteNewtTunnelDescription": "Einfachster Weg, einen Einstiegspunkt in jedes Netzwerk zu erstellen. Keine zusätzliche Einrichtung.",
"siteWg": "Einfacher WireGuard Tunnel", "siteWg": "Einfacher WireGuard Tunnel",
"siteWgDescription": "Verwende jeden WireGuard-Client, um einen Tunnel einzurichten. Manuelles NAT-Setup erforderlich.", "siteWgDescription": "Verwende jeden WireGuard-Client, um einen Tunnel einzurichten. Manuelles NAT-Setup erforderlich.",
@@ -107,7 +107,7 @@
"siteSeeAll": "Alle Standorte anzeigen", "siteSeeAll": "Alle Standorte anzeigen",
"siteTunnelDescription": "Legen Sie fest, wie Sie sich mit dem Standort verbinden möchten", "siteTunnelDescription": "Legen Sie fest, wie Sie sich mit dem Standort verbinden möchten",
"siteNewtCredentials": "Zugangsdaten", "siteNewtCredentials": "Zugangsdaten",
"siteNewtCredentialsDescription": "So wird sich die Seite mit dem Server authentifizieren", "siteNewtCredentialsDescription": "So wird sich der Standort mit dem Server authentifizieren",
"remoteNodeCredentialsDescription": "So wird sich der entfernte Node mit dem Server authentifizieren", "remoteNodeCredentialsDescription": "So wird sich der entfernte Node mit dem Server authentifizieren",
"siteCredentialsSave": "Anmeldedaten speichern", "siteCredentialsSave": "Anmeldedaten speichern",
"siteCredentialsSaveDescription": "Du kannst das nur einmal sehen. Stelle sicher, dass du es an einen sicheren Ort kopierst.", "siteCredentialsSaveDescription": "Du kannst das nur einmal sehen. Stelle sicher, dass du es an einen sicheren Ort kopierst.",
@@ -2503,7 +2503,7 @@
"deviceModel": "Gerätemodell", "deviceModel": "Gerätemodell",
"serialNumber": "Seriennummer", "serialNumber": "Seriennummer",
"hostname": "Hostname", "hostname": "Hostname",
"firstSeen": "Erster Blick", "firstSeen": "Zuerst gesehen",
"lastSeen": "Zuletzt gesehen", "lastSeen": "Zuletzt gesehen",
"biometricsEnabled": "Biometrie aktiviert", "biometricsEnabled": "Biometrie aktiviert",
"diskEncrypted": "Festplatte verschlüsselt", "diskEncrypted": "Festplatte verschlüsselt",

View File

@@ -18,8 +18,6 @@
"componentsMember": "You're a member of {count, plural, =0 {no organization} one {one organization} other {# organizations}}.", "componentsMember": "You're a member of {count, plural, =0 {no organization} one {one organization} other {# organizations}}.",
"componentsInvalidKey": "Invalid or expired license keys detected. Follow license terms to continue using all features.", "componentsInvalidKey": "Invalid or expired license keys detected. Follow license terms to continue using all features.",
"dismiss": "Dismiss", "dismiss": "Dismiss",
"subscriptionViolationMessage": "You're beyond your limits for your current plan. Correct the problem by removing sites, users, or other resources to stay within your plan.",
"subscriptionViolationViewBilling": "View billing",
"componentsLicenseViolation": "License Violation: This server is using {usedSites} sites which exceeds its licensed limit of {maxSites} sites. Follow license terms to continue using all features.", "componentsLicenseViolation": "License Violation: This server is using {usedSites} sites which exceeds its licensed limit of {maxSites} sites. Follow license terms to continue using all features.",
"componentsSupporterMessage": "Thank you for supporting Pangolin as a {tier}!", "componentsSupporterMessage": "Thank you for supporting Pangolin as a {tier}!",
"inviteErrorNotValid": "We're sorry, but it looks like the invite you're trying to access has not been accepted or is no longer valid.", "inviteErrorNotValid": "We're sorry, but it looks like the invite you're trying to access has not been accepted or is no longer valid.",
@@ -57,7 +55,7 @@
"siteDescription": "Create and manage sites to enable connectivity to private networks", "siteDescription": "Create and manage sites to enable connectivity to private networks",
"sitesBannerTitle": "Connect Any Network", "sitesBannerTitle": "Connect Any Network",
"sitesBannerDescription": "A site is a connection to a remote network that allows Pangolin to provide access to resources, whether public or private, to users anywhere. Install the site network connector (Newt) anywhere you can run a binary or container to establish the connection.", "sitesBannerDescription": "A site is a connection to a remote network that allows Pangolin to provide access to resources, whether public or private, to users anywhere. Install the site network connector (Newt) anywhere you can run a binary or container to establish the connection.",
"sitesBannerButtonText": "Install Site Connector", "sitesBannerButtonText": "Install Site",
"approvalsBannerTitle": "Approve or Deny Device Access", "approvalsBannerTitle": "Approve or Deny Device Access",
"approvalsBannerDescription": "Review and approve or deny device access requests from users. When device approvals are required, users must get admin approval before their devices can connect to your organization's resources.", "approvalsBannerDescription": "Review and approve or deny device access requests from users. When device approvals are required, users must get admin approval before their devices can connect to your organization's resources.",
"approvalsBannerButtonText": "Learn More", "approvalsBannerButtonText": "Learn More",
@@ -81,8 +79,8 @@
"siteConfirmCopy": "I have copied the config", "siteConfirmCopy": "I have copied the config",
"searchSitesProgress": "Search sites...", "searchSitesProgress": "Search sites...",
"siteAdd": "Add Site", "siteAdd": "Add Site",
"siteInstallNewt": "Install Site", "siteInstallNewt": "Install Newt",
"siteInstallNewtDescription": "Install the site connector for your system", "siteInstallNewtDescription": "Get Newt running on your system",
"WgConfiguration": "WireGuard Configuration", "WgConfiguration": "WireGuard Configuration",
"WgConfigurationDescription": "Use the following configuration to connect to the network", "WgConfigurationDescription": "Use the following configuration to connect to the network",
"operatingSystem": "Operating System", "operatingSystem": "Operating System",
@@ -1406,10 +1404,10 @@
"billingUsageLimitsOverview": "Usage Limits Overview", "billingUsageLimitsOverview": "Usage Limits Overview",
"billingMonitorUsage": "Monitor your usage against configured limits. If you need limits increased please contact us support@pangolin.net.", "billingMonitorUsage": "Monitor your usage against configured limits. If you need limits increased please contact us support@pangolin.net.",
"billingDataUsage": "Data Usage", "billingDataUsage": "Data Usage",
"billingSites": "Sites", "billingOnlineTime": "Site Online Time",
"billingUsers": "Users", "billingUsers": "Active Users",
"billingDomains": "Domains", "billingDomains": "Active Domains",
"billingRemoteExitNodes": "Remote Nodes", "billingRemoteExitNodes": "Active Self-hosted Nodes",
"billingNoLimitConfigured": "No limit configured", "billingNoLimitConfigured": "No limit configured",
"billingEstimatedPeriod": "Estimated Billing Period", "billingEstimatedPeriod": "Estimated Billing Period",
"billingIncludedUsage": "Included Usage", "billingIncludedUsage": "Included Usage",
@@ -1434,10 +1432,10 @@
"billingFailedToGetPortalUrl": "Failed to get portal URL", "billingFailedToGetPortalUrl": "Failed to get portal URL",
"billingPortalError": "Portal Error", "billingPortalError": "Portal Error",
"billingDataUsageInfo": "You're charged for all data transferred through your secure tunnels when connected to the cloud. This includes both incoming and outgoing traffic across all your sites. When you reach your limit, your sites will disconnect until you upgrade your plan or reduce usage. Data is not charged when using nodes.", "billingDataUsageInfo": "You're charged for all data transferred through your secure tunnels when connected to the cloud. This includes both incoming and outgoing traffic across all your sites. When you reach your limit, your sites will disconnect until you upgrade your plan or reduce usage. Data is not charged when using nodes.",
"billingSInfo": "How many sites you can use", "billingOnlineTimeInfo": "You're charged based on how long your sites stay connected to the cloud. For example, 44,640 minutes equals one site running 24/7 for a full month. When you reach your limit, your sites will disconnect until you upgrade your plan or reduce usage. Time is not charged when using nodes.",
"billingUsersInfo": "How many users you can use", "billingUsersInfo": "You're charged for each user in the organization. Billing is calculated daily based on the number of active user accounts in your org.",
"billingDomainInfo": "How many domains you can use", "billingDomainInfo": "You're charged for each domain in the organization. Billing is calculated daily based on the number of active domain accounts in your org.",
"billingRemoteExitNodesInfo": "How many remote nodes you can use", "billingRemoteExitNodesInfo": "You're charged for each managed Node in the organization. Billing is calculated daily based on the number of active managed Nodes in your org.",
"billingLicenseKeys": "License Keys", "billingLicenseKeys": "License Keys",
"billingLicenseKeysDescription": "Manage your license key subscriptions", "billingLicenseKeysDescription": "Manage your license key subscriptions",
"billingLicenseSubscription": "License Subscription", "billingLicenseSubscription": "License Subscription",
@@ -1446,6 +1444,7 @@
"billingQuantity": "Quantity", "billingQuantity": "Quantity",
"billingTotal": "total", "billingTotal": "total",
"billingModifyLicenses": "Modify License Subscription", "billingModifyLicenses": "Modify License Subscription",
"billingPricingCalculatorLink": "View Pricing Calculator",
"domainNotFound": "Domain Not Found", "domainNotFound": "Domain Not Found",
"domainNotFoundDescription": "This resource is disabled because the domain no longer exists our system. Please set a new domain for this resource.", "domainNotFoundDescription": "This resource is disabled because the domain no longer exists our system. Please set a new domain for this resource.",
"failed": "Failed", "failed": "Failed",
@@ -1522,32 +1521,6 @@
"resourcePortRequired": "Port number is required for non-HTTP resources", "resourcePortRequired": "Port number is required for non-HTTP resources",
"resourcePortNotAllowed": "Port number should not be set for HTTP resources", "resourcePortNotAllowed": "Port number should not be set for HTTP resources",
"billingPricingCalculatorLink": "Pricing Calculator", "billingPricingCalculatorLink": "Pricing Calculator",
"billingYourPlan": "Your Plan",
"billingViewOrModifyPlan": "View or modify your current plan",
"billingViewPlanDetails": "View Plan Details",
"billingUsageAndLimits": "Usage and Limits",
"billingViewUsageAndLimits": "View your plan's limits and current usage",
"billingCurrentUsage": "Current Usage",
"billingMaximumLimits": "Maximum Limits",
"billingRemoteNodes": "Remote Nodes",
"billingUnlimited": "Unlimited",
"billingPaidLicenseKeys": "Paid License Keys",
"billingManageLicenseSubscription": "Manage your subscription for paid self-hosted license keys",
"billingCurrentKeys": "Current Keys",
"billingModifyCurrentPlan": "Modify Current Plan",
"billingConfirmUpgrade": "Confirm Upgrade",
"billingConfirmDowngrade": "Confirm Downgrade",
"billingConfirmUpgradeDescription": "You are about to upgrade your plan. Review the new limits and pricing below.",
"billingConfirmDowngradeDescription": "You are about to downgrade your plan. Review the new limits and pricing below.",
"billingPlanIncludes": "Plan Includes",
"billingProcessing": "Processing...",
"billingConfirmUpgradeButton": "Confirm Upgrade",
"billingConfirmDowngradeButton": "Confirm Downgrade",
"billingLimitViolationWarning": "Usage Exceeds New Plan Limits",
"billingLimitViolationDescription": "Your current usage exceeds the limits of this plan. After downgrading, all actions will be disabled until you reduce usage within the new limits. Please review the features below that are currently over the limits. Limits in violation:",
"billingFeatureLossWarning": "Feature Availability Notice",
"billingFeatureLossDescription": "By downgrading, features not available in the new plan will be automatically disabled. Some settings and configurations may be lost. Please review the pricing matrix to understand which features will no longer be available.",
"billingUsageExceedsLimit": "Current usage ({current}) exceeds limit ({limit})",
"signUpTerms": { "signUpTerms": {
"IAgreeToThe": "I agree to the", "IAgreeToThe": "I agree to the",
"termsOfService": "terms of service", "termsOfService": "terms of service",
@@ -1572,8 +1545,8 @@
"addressDescription": "The internal address of the client. Must fall within the organization's subnet.", "addressDescription": "The internal address of the client. Must fall within the organization's subnet.",
"selectSites": "Select sites", "selectSites": "Select sites",
"sitesDescription": "The client will have connectivity to the selected sites", "sitesDescription": "The client will have connectivity to the selected sites",
"clientInstallOlm": "Install Machine Client", "clientInstallOlm": "Install Olm",
"clientInstallOlmDescription": "Install the machine client for your system", "clientInstallOlmDescription": "Get Olm running on your system",
"clientOlmCredentials": "Credentials", "clientOlmCredentials": "Credentials",
"clientOlmCredentialsDescription": "This is how the client will authenticate with the server", "clientOlmCredentialsDescription": "This is how the client will authenticate with the server",
"olmEndpoint": "Endpoint", "olmEndpoint": "Endpoint",
@@ -1962,13 +1935,6 @@
"orgAuthBackToSignIn": "Back to standard sign in", "orgAuthBackToSignIn": "Back to standard sign in",
"orgAuthNoAccount": "Don't have an account?", "orgAuthNoAccount": "Don't have an account?",
"subscriptionRequiredToUse": "A subscription is required to use this feature.", "subscriptionRequiredToUse": "A subscription is required to use this feature.",
"mustUpgradeToUse": "You must upgrade your subscription to use this feature.",
"subscriptionRequiredTierToUse": "This feature requires <tierLink>{tier}</tierLink> or higher.",
"upgradeToTierToUse": "Upgrade to <tierLink>{tier}</tierLink> or higher to use this feature.",
"subscriptionTierTier1": "Home",
"subscriptionTierTier2": "Team",
"subscriptionTierTier3": "Business",
"subscriptionTierEnterprise": "Enterprise",
"idpDisabled": "Identity providers are disabled.", "idpDisabled": "Identity providers are disabled.",
"orgAuthPageDisabled": "Organization auth page is disabled.", "orgAuthPageDisabled": "Organization auth page is disabled.",
"domainRestartedDescription": "Domain verification restarted successfully", "domainRestartedDescription": "Domain verification restarted successfully",
@@ -2281,7 +2247,6 @@
"actionLogsDescription": "View a history of actions performed in this organization", "actionLogsDescription": "View a history of actions performed in this organization",
"accessLogsDescription": "View access auth requests for resources in this organization", "accessLogsDescription": "View access auth requests for resources in this organization",
"licenseRequiredToUse": "An Enterprise license is required to use this feature.", "licenseRequiredToUse": "An Enterprise license is required to use this feature.",
"ossEnterpriseEditionRequired": "The <enterpriseEditionLink>Enterprise Edition</enterpriseEditionLink> is required to use this feature.",
"certResolver": "Certificate Resolver", "certResolver": "Certificate Resolver",
"certResolverDescription": "Select the certificate resolver to use for this resource.", "certResolverDescription": "Select the certificate resolver to use for this resource.",
"selectCertResolver": "Select Certificate Resolver", "selectCertResolver": "Select Certificate Resolver",

View File

@@ -14,12 +14,10 @@
"dev": "NODE_ENV=development ENVIRONMENT=dev tsx watch server/index.ts", "dev": "NODE_ENV=development ENVIRONMENT=dev tsx watch server/index.ts",
"dev:check": "npx tsc --noEmit && npm run format:check", "dev:check": "npx tsc --noEmit && npm run format:check",
"dev:setup": "cp config/config.example.yml config/config.yml && npm run set:oss && npm run set:sqlite && npm run db:generate && npm run db:sqlite:push", "dev:setup": "cp config/config.example.yml config/config.yml && npm run set:oss && npm run set:sqlite && npm run db:generate && npm run db:sqlite:push",
"db:pg:generate": "drizzle-kit generate --config=./drizzle.pg.config.ts", "db:generate": "drizzle-kit generate --config=./drizzle.config.ts",
"db:sqlite:generate": "drizzle-kit generate --config=./drizzle.sqlite.config.ts",
"db:pg:push": "npx tsx server/db/pg/migrate.ts", "db:pg:push": "npx tsx server/db/pg/migrate.ts",
"db:sqlite:push": "npx tsx server/db/sqlite/migrate.ts", "db:sqlite:push": "npx tsx server/db/sqlite/migrate.ts",
"db:pg:studio": "drizzle-kit studio --config=./drizzle.pg.config.ts", "db:studio": "drizzle-kit studio --config=./drizzle.config.ts",
"db:sqlite:studio": "drizzle-kit studio --config=./drizzle.sqlite.config.ts",
"db:clear-migrations": "rm -rf server/migrations", "db:clear-migrations": "rm -rf server/migrations",
"set:oss": "echo 'export const build = \"oss\" as \"saas\" | \"enterprise\" | \"oss\";' > server/build.ts && cp tsconfig.oss.json tsconfig.json", "set:oss": "echo 'export const build = \"oss\" as \"saas\" | \"enterprise\" | \"oss\";' > server/build.ts && cp tsconfig.oss.json tsconfig.json",
"set:saas": "echo 'export const build = \"saas\" as \"saas\" | \"enterprise\" | \"oss\";' > server/build.ts && cp tsconfig.saas.json tsconfig.json", "set:saas": "echo 'export const build = \"saas\" as \"saas\" | \"enterprise\" | \"oss\";' > server/build.ts && cp tsconfig.saas.json tsconfig.json",

View File

@@ -82,14 +82,11 @@ export const subscriptions = pgTable("subscriptions", {
canceledAt: bigint("canceledAt", { mode: "number" }), canceledAt: bigint("canceledAt", { mode: "number" }),
createdAt: bigint("createdAt", { mode: "number" }).notNull(), createdAt: bigint("createdAt", { mode: "number" }).notNull(),
updatedAt: bigint("updatedAt", { mode: "number" }), updatedAt: bigint("updatedAt", { mode: "number" }),
version: integer("version"), billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" })
billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" }),
type: varchar("type", { length: 50 }) // tier1, tier2, tier3, or license
}); });
export const subscriptionItems = pgTable("subscriptionItems", { export const subscriptionItems = pgTable("subscriptionItems", {
subscriptionItemId: serial("subscriptionItemId").primaryKey(), subscriptionItemId: serial("subscriptionItemId").primaryKey(),
stripeSubscriptionItemId: varchar("stripeSubscriptionItemId", { length: 255 }),
subscriptionId: varchar("subscriptionId", { length: 255 }) subscriptionId: varchar("subscriptionId", { length: 255 })
.notNull() .notNull()
.references(() => subscriptions.subscriptionId, { .references(() => subscriptions.subscriptionId, {
@@ -97,7 +94,6 @@ export const subscriptionItems = pgTable("subscriptionItems", {
}), }),
planId: varchar("planId", { length: 255 }).notNull(), planId: varchar("planId", { length: 255 }).notNull(),
priceId: varchar("priceId", { length: 255 }), priceId: varchar("priceId", { length: 255 }),
featureId: varchar("featureId", { length: 255 }),
meterId: varchar("meterId", { length: 255 }), meterId: varchar("meterId", { length: 255 }),
unitAmount: real("unitAmount"), unitAmount: real("unitAmount"),
tiers: text("tiers"), tiers: text("tiers"),
@@ -140,7 +136,6 @@ export const limits = pgTable("limits", {
}) })
.notNull(), .notNull(),
value: real("value"), value: real("value"),
override: boolean("override").default(false),
description: text("description") description: text("description")
}); });

View File

@@ -70,9 +70,7 @@ export const subscriptions = sqliteTable("subscriptions", {
canceledAt: integer("canceledAt"), canceledAt: integer("canceledAt"),
createdAt: integer("createdAt").notNull(), createdAt: integer("createdAt").notNull(),
updatedAt: integer("updatedAt"), updatedAt: integer("updatedAt"),
version: integer("version"), billingCycleAnchor: integer("billingCycleAnchor")
billingCycleAnchor: integer("billingCycleAnchor"),
type: text("type") // tier1, tier2, tier3, or license
}); });
export const subscriptionItems = sqliteTable("subscriptionItems", { export const subscriptionItems = sqliteTable("subscriptionItems", {
@@ -86,7 +84,6 @@ export const subscriptionItems = sqliteTable("subscriptionItems", {
}), }),
planId: text("planId").notNull(), planId: text("planId").notNull(),
priceId: text("priceId"), priceId: text("priceId"),
featureId: text("featureId"),
meterId: text("meterId"), meterId: text("meterId"),
unitAmount: real("unitAmount"), unitAmount: real("unitAmount"),
tiers: text("tiers"), tiers: text("tiers"),
@@ -129,7 +126,6 @@ export const limits = sqliteTable("limits", {
}) })
.notNull(), .notNull(),
value: real("value"), value: real("value"),
override: integer("override", { mode: "boolean" }).default(false),
description: text("description") description: text("description")
}); });

View File

@@ -105,13 +105,11 @@ function getOpenApiDocumentation() {
servers: [{ url: "/v1" }] servers: [{ url: "/v1" }]
}); });
if (!process.env.DISABLE_GEN_OPENAPI) { // convert to yaml and save to file
// convert to yaml and save to file const outputPath = path.join(APP_PATH, "openapi.yaml");
const outputPath = path.join(APP_PATH, "openapi.yaml"); const yamlOutput = yaml.dump(generated);
const yamlOutput = yaml.dump(generated); fs.writeFileSync(outputPath, yamlOutput, "utf8");
fs.writeFileSync(outputPath, yamlOutput, "utf8"); logger.info(`OpenAPI documentation saved to ${outputPath}`);
logger.info(`OpenAPI documentation saved to ${outputPath}`);
}
return generated; return generated;
} }

View File

@@ -1,41 +1,30 @@
import Stripe from "stripe";
export enum FeatureId { export enum FeatureId {
SITE_UPTIME = "siteUptime",
USERS = "users", USERS = "users",
SITES = "sites",
EGRESS_DATA_MB = "egressDataMb", EGRESS_DATA_MB = "egressDataMb",
DOMAINS = "domains", DOMAINS = "domains",
REMOTE_EXIT_NODES = "remoteExitNodes", REMOTE_EXIT_NODES = "remoteExitNodes"
TIER1 = "tier1"
} }
export async function getFeatureDisplayName(featureId: FeatureId): Promise<string> { export const FeatureMeterIds: Record<FeatureId, string> = {
switch (featureId) { [FeatureId.SITE_UPTIME]: "mtr_61Srrej5wUJuiTWgo41D3Ee2Ir7WmDLU",
case FeatureId.USERS: [FeatureId.USERS]: "mtr_61SrreISyIWpwUNGR41D3Ee2Ir7WmQro",
return "Users"; [FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW",
case FeatureId.SITES: [FeatureId.DOMAINS]: "mtr_61Ss9nIKDNMw0LDRU41D3Ee2Ir7WmRPU",
return "Sites"; [FeatureId.REMOTE_EXIT_NODES]: "mtr_61T86UXnfxTVXy9sD41D3Ee2Ir7WmFTE"
case FeatureId.EGRESS_DATA_MB:
return "Egress Data (MB)";
case FeatureId.DOMAINS:
return "Domains";
case FeatureId.REMOTE_EXIT_NODES:
return "Remote Exit Nodes";
case FeatureId.TIER1:
return "Home Lab";
default:
return featureId;
}
}
// this is from the old system
export const FeatureMeterIds: Partial<Record<FeatureId, string>> = { // right now we are not charging for any data
// [FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW"
}; };
export const FeatureMeterIdsSandbox: Partial<Record<FeatureId, string>> = { export const FeatureMeterIdsSandbox: Record<FeatureId, string> = {
// [FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ" [FeatureId.SITE_UPTIME]: "mtr_test_61Snh3cees4w60gv841DCpkOb237BDEu",
[FeatureId.USERS]: "mtr_test_61Sn5fLtq1gSfRkyA41DCpkOb237B6au",
[FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ",
[FeatureId.DOMAINS]: "mtr_test_61SsA8qrdAlgPpFRQ41DCpkOb237BGts",
[FeatureId.REMOTE_EXIT_NODES]: "mtr_test_61T86Vqmwa3D9ra3341DCpkOb237B94K"
}; };
export function getFeatureMeterId(featureId: FeatureId): string | undefined { export function getFeatureMeterId(featureId: FeatureId): string {
if ( if (
process.env.ENVIRONMENT == "prod" && process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true" process.env.SANDBOX_MODE !== "true"
@@ -54,81 +43,45 @@ export function getFeatureIdByMetricId(
)?.[0]; )?.[0];
} }
export type FeaturePriceSet = Partial<Record<FeatureId, string>>; export type FeaturePriceSet = {
[key in Exclude<FeatureId, FeatureId.DOMAINS>]: string;
export const homeLabFeaturePriceSet: FeaturePriceSet = { } & {
[FeatureId.TIER1]: "price_1SzVE3D3Ee2Ir7Wm6wT5Dl3G" [FeatureId.DOMAINS]?: string; // Optional since domains are not billed
}; };
export const homeLabFeaturePriceSetSandbox: FeaturePriceSet = { export const standardFeaturePriceSet: FeaturePriceSet = {
[FeatureId.TIER1]: "price_1SxgpPDCpkOb237Bfo4rIsoT" // Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RrQc4D3Ee2Ir7WmaJGZ3MtF",
[FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea",
[FeatureId.EGRESS_DATA_MB]: "price_1RrQXFD3Ee2Ir7WmvGDlgxQk",
// [FeatureId.DOMAINS]: "price_1Rz3tMD3Ee2Ir7Wm5qLeASzC",
[FeatureId.REMOTE_EXIT_NODES]: "price_1S46weD3Ee2Ir7Wm94KEHI4h"
}; };
export function getHomeLabFeaturePriceSet(): FeaturePriceSet { export const standardFeaturePriceSetSandbox: FeaturePriceSet = {
// Free tier matches the freeLimitSet
[FeatureId.SITE_UPTIME]: "price_1RefFBDCpkOb237BPrKZ8IEU",
[FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF",
[FeatureId.EGRESS_DATA_MB]: "price_1Rfp9LDCpkOb237BwuN5Oiu0",
// [FeatureId.DOMAINS]: "price_1Ryi88DCpkOb237B2D6DM80b",
[FeatureId.REMOTE_EXIT_NODES]: "price_1RyiZvDCpkOb237BXpmoIYJL"
};
export function getStandardFeaturePriceSet(): FeaturePriceSet {
if ( if (
process.env.ENVIRONMENT == "prod" && process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true" process.env.SANDBOX_MODE !== "true"
) { ) {
return homeLabFeaturePriceSet; return standardFeaturePriceSet;
} else { } else {
return homeLabFeaturePriceSetSandbox; return standardFeaturePriceSetSandbox;
} }
} }
export const tier2FeaturePriceSet: FeaturePriceSet = { export function getLineItems(
[FeatureId.USERS]: "price_1SzVCcD3Ee2Ir7Wmn6U3KvPN" featurePriceSet: FeaturePriceSet
}; ): Stripe.Checkout.SessionCreateParams.LineItem[] {
return Object.entries(featurePriceSet).map(([featureId, priceId]) => ({
export const tier2FeaturePriceSetSandbox: FeaturePriceSet = { price: priceId
[FeatureId.USERS]: "price_1SxaEHDCpkOb237BD9lBkPiR" }));
};
export function getStarterFeaturePriceSet(): FeaturePriceSet {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return tier2FeaturePriceSet;
} else {
return tier2FeaturePriceSetSandbox;
}
}
export const tier3FeaturePriceSet: FeaturePriceSet = {
[FeatureId.USERS]: "price_1SzVDKD3Ee2Ir7WmPtOKNusv"
};
export const tier3FeaturePriceSetSandbox: FeaturePriceSet = {
[FeatureId.USERS]: "price_1SxaEODCpkOb237BiXdCBSfs"
};
export function getScaleFeaturePriceSet(): FeaturePriceSet {
if (
process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true"
) {
return tier3FeaturePriceSet;
} else {
return tier3FeaturePriceSetSandbox;
}
}
export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined {
// Check all feature price sets
const allPriceSets = [
getHomeLabFeaturePriceSet(),
getStarterFeaturePriceSet(),
getScaleFeaturePriceSet()
];
for (const priceSet of allPriceSets) {
const entry = (Object.entries(priceSet) as [FeatureId, string][]).find(
([_, price]) => price === priceId
);
if (entry) {
return entry[0];
}
}
return undefined;
} }

View File

@@ -1,25 +0,0 @@
import Stripe from "stripe";
import { FeatureId, FeaturePriceSet } from "./features";
import { usageService } from "./usageService";
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -1,67 +1,50 @@
import { FeatureId } from "./features"; import { FeatureId } from "./features";
export type LimitSet = Partial<{ export type LimitSet = {
[key in FeatureId]: { [key in FeatureId]: {
value: number | null; // null indicates no limit value: number | null; // null indicates no limit
description?: string; description?: string;
}; };
}>; };
export const sandboxLimitSet: LimitSet = { export const sandboxLimitSet: LimitSet = {
[FeatureId.SITE_UPTIME]: { value: 2880, description: "Sandbox limit" }, // 1 site up for 2 days
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" }, [FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.SITES]: { value: 1, description: "Sandbox limit" }, [FeatureId.EGRESS_DATA_MB]: { value: 1000, description: "Sandbox limit" }, // 1 GB
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" }, [FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }
}; };
export const freeLimitSet: LimitSet = { export const freeLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 5, description: "Starter limit" }, [FeatureId.SITE_UPTIME]: { value: 46080, description: "Free tier limit" }, // 1 site up for 32 days
[FeatureId.SITES]: { value: 5, description: "Starter limit" }, [FeatureId.USERS]: { value: 3, description: "Free tier limit" },
[FeatureId.DOMAINS]: { value: 5, description: "Starter limit" }, [FeatureId.EGRESS_DATA_MB]: {
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Starter limit" }, value: 25000,
description: "Free tier limit"
}, // 25 GB
[FeatureId.DOMAINS]: { value: 3, description: "Free tier limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Free tier limit" }
}; };
export const tier1LimitSet: LimitSet = { export const subscribedLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 7, description: "Home limit" }, [FeatureId.SITE_UPTIME]: {
[FeatureId.SITES]: { value: 10, description: "Home limit" }, value: 2232000,
[FeatureId.DOMAINS]: { value: 10, description: "Home limit" }, description: "Contact us to increase soft limit."
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" }, }, // 50 sites up for 31 days
};
export const tier2LimitSet: LimitSet = {
[FeatureId.USERS]: { [FeatureId.USERS]: {
value: 100, value: 150,
description: "Team limit" description: "Contact us to increase soft limit."
},
[FeatureId.SITES]: {
value: 50,
description: "Team limit"
}, },
[FeatureId.EGRESS_DATA_MB]: {
value: 12000000,
description: "Contact us to increase soft limit."
}, // 12000 GB
[FeatureId.DOMAINS]: { [FeatureId.DOMAINS]: {
value: 50,
description: "Team limit"
},
[FeatureId.REMOTE_EXIT_NODES]: {
value: 3,
description: "Team limit"
},
};
export const tier3LimitSet: LimitSet = {
[FeatureId.USERS]: {
value: 500,
description: "Business limit"
},
[FeatureId.SITES]: {
value: 250, value: 250,
description: "Business limit" description: "Contact us to increase soft limit."
},
[FeatureId.DOMAINS]: {
value: 100,
description: "Business limit"
}, },
[FeatureId.REMOTE_EXIT_NODES]: { [FeatureId.REMOTE_EXIT_NODES]: {
value: 20, value: 5,
description: "Business limit" description: "Contact us to increase soft limit."
}, }
}; };

View File

@@ -2,7 +2,6 @@ import { db, limits } from "@server/db";
import { and, eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { LimitSet } from "./limitSet"; import { LimitSet } from "./limitSet";
import { FeatureId } from "./features"; import { FeatureId } from "./features";
import logger from "@server/logger";
class LimitService { class LimitService {
async applyLimitSetToOrg(orgId: string, limitSet: LimitSet): Promise<void> { async applyLimitSetToOrg(orgId: string, limitSet: LimitSet): Promise<void> {
@@ -14,21 +13,6 @@ class LimitService {
for (const [featureId, entry] of limitEntries) { for (const [featureId, entry] of limitEntries) {
const limitId = `${orgId}-${featureId}`; const limitId = `${orgId}-${featureId}`;
const { value, description } = entry; const { value, description } = entry;
// get the limit first
const [limit] = await trx
.select()
.from(limits)
.where(eq(limits.limitId, limitId))
.limit(1);
// check if its overriden
if (limit && limit.override) {
logger.debug(
`Skipping limit ${limitId} for org ${orgId} since it is overridden...`
);
continue;
}
await trx await trx
.insert(limits) .insert(limits)
.values({ limitId, orgId, featureId, value, description }); .values({ limitId, orgId, featureId, value, description });

View File

@@ -1,50 +0,0 @@
import { Tier } from "@server/types/Tiers";
export enum TierFeature {
OrgOidc = "orgOidc",
LoginPageDomain = "loginPageDomain", // handle downgrade by removing custom domain
DeviceApprovals = "deviceApprovals", // handle downgrade by disabling device approvals
LoginPageBranding = "loginPageBranding", // handle downgrade by setting to default branding
LogExport = "logExport",
AccessLogs = "accessLogs", // set the retention period to none on downgrade
ActionLogs = "actionLogs", // set the retention period to none on downgrade
RotateCredentials = "rotateCredentials",
MaintencePage = "maintencePage", // handle downgrade
DevicePosture = "devicePosture",
TwoFactorEnforcement = "twoFactorEnforcement", // handle downgrade by setting to optional
SessionDurationPolicies = "sessionDurationPolicies", // handle downgrade by setting to default duration
PasswordExpirationPolicies = "passwordExpirationPolicies", // handle downgrade by setting to default duration
AutoProvisioning = "autoProvisioning" // handle downgrade by disabling auto provisioning
}
export const tierMatrix: Record<TierFeature, Tier[]> = {
[TierFeature.OrgOidc]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.LoginPageDomain]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.DeviceApprovals]: ["tier1", "tier3", "enterprise"],
[TierFeature.LoginPageBranding]: ["tier1", "tier3", "enterprise"],
[TierFeature.LogExport]: ["tier3", "enterprise"],
[TierFeature.AccessLogs]: ["tier2", "tier3", "enterprise"],
[TierFeature.ActionLogs]: ["tier2", "tier3", "enterprise"],
[TierFeature.RotateCredentials]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.MaintencePage]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.DevicePosture]: ["tier2", "tier3", "enterprise"],
[TierFeature.TwoFactorEnforcement]: [
"tier1",
"tier2",
"tier3",
"enterprise"
],
[TierFeature.SessionDurationPolicies]: [
"tier1",
"tier2",
"tier3",
"enterprise"
],
[TierFeature.PasswordExpirationPolicies]: [
"tier1",
"tier2",
"tier3",
"enterprise"
],
[TierFeature.AutoProvisioning]: ["tier1", "tier3", "enterprise"]
};

View File

@@ -0,0 +1,34 @@
export enum TierId {
STANDARD = "standard"
}
export type TierPriceSet = {
[key in TierId]: string;
};
export const tierPriceSet: TierPriceSet = {
// Free tier matches the freeLimitSet
[TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0"
};
export const tierPriceSetSandbox: TierPriceSet = {
// Free tier matches the freeLimitSet
// when matching tier the keys closer to 0 index are matched first so list the tiers in descending order of value
[TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m"
};
export function getTierPriceSet(
environment?: string,
sandbox_mode?: boolean
): TierPriceSet {
if (
(process.env.ENVIRONMENT == "prod" &&
process.env.SANDBOX_MODE !== "true") ||
(environment === "prod" && sandbox_mode !== true)
) {
// THIS GETS LOADED CLIENT SIDE AND SERVER SIDE
return tierPriceSet;
} else {
return tierPriceSetSandbox;
}
}

View File

@@ -1,6 +1,8 @@
import { eq, sql, and } from "drizzle-orm"; import { eq, sql, and } from "drizzle-orm";
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
import { PutObjectCommand } from "@aws-sdk/client-s3"; import { PutObjectCommand } from "@aws-sdk/client-s3";
import * as fs from "fs/promises";
import * as path from "path";
import { import {
db, db,
usage, usage,
@@ -30,7 +32,11 @@ interface StripeEvent {
} }
export function noop() { export function noop() {
if (build !== "saas") { if (
build !== "saas" ||
!process.env.S3_BUCKET ||
!process.env.LOCAL_FILE_PATH
) {
return true; return true;
} }
return false; return false;
@@ -38,40 +44,31 @@ export function noop() {
export class UsageService { export class UsageService {
private bucketName: string | undefined; private bucketName: string | undefined;
private events: StripeEvent[] = []; private currentEventFile: string | null = null;
private lastUploadTime: number = Date.now(); private currentFileStartTime: number = 0;
private isUploading: boolean = false; private eventsDir: string | undefined;
private uploadingFiles: Set<string> = new Set();
constructor() { constructor() {
if (noop()) { if (noop()) {
return; return;
} }
// this.bucketName = privateConfig.getRawPrivateConfig().stripe?.s3Bucket;
// this.eventsDir = privateConfig.getRawPrivateConfig().stripe?.localFilePath;
this.bucketName = process.env.S3_BUCKET || undefined;
this.eventsDir = process.env.LOCAL_FILE_PATH || undefined;
// this.bucketName = process.env.S3_BUCKET || undefined; // Ensure events directory exists
this.initializeEventsDirectory().then(() => {
this.uploadPendingEventFilesOnStartup();
});
// // Periodically check and upload events // Periodically check for old event files to upload
// setInterval(() => { setInterval(() => {
// this.checkAndUploadEvents().catch((err) => { this.uploadOldEventFiles().catch((err) => {
// logger.error("Error in periodic event upload:", err); logger.error("Error in periodic event file upload:", err);
// }); });
// }, 30000); // every 30 seconds }, 30000); // every 30 seconds
// // Handle graceful shutdown on SIGTERM
// process.on("SIGTERM", async () => {
// logger.info(
// "SIGTERM received, uploading events before shutdown..."
// );
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// });
// // Handle SIGINT as well (Ctrl+C)
// process.on("SIGINT", async () => {
// logger.info("SIGINT received, uploading events before shutdown...");
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// process.exit(0);
// });
} }
/** /**
@@ -81,6 +78,85 @@ export class UsageService {
return Math.round(value * 100000000000) / 100000000000; // 11 decimal places return Math.round(value * 100000000000) / 100000000000; // 11 decimal places
} }
private async initializeEventsDirectory(): Promise<void> {
if (!this.eventsDir) {
logger.warn(
"Stripe local file path is not configured, skipping events directory initialization."
);
return;
}
try {
await fs.mkdir(this.eventsDir, { recursive: true });
} catch (error) {
logger.error("Failed to create events directory:", error);
}
}
private async uploadPendingEventFilesOnStartup(): Promise<void> {
if (!this.eventsDir || !this.bucketName) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping leftover event file upload."
);
return;
}
try {
const files = await fs.readdir(this.eventsDir);
for (const file of files) {
if (file.endsWith(".json")) {
const filePath = path.join(this.eventsDir, file);
try {
const fileContent = await fs.readFile(
filePath,
"utf-8"
);
const events = JSON.parse(fileContent);
if (Array.isArray(events) && events.length > 0) {
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: file,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
// Check if file still exists before unlinking
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`Startup file ${file} was already deleted`
);
}
logger.info(
`Uploaded leftover event file ${file} to S3 with ${events.length} events`
);
} else {
// Remove empty file
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`Empty startup file ${file} was already deleted`
);
}
}
} catch (err) {
logger.error(
`Error processing leftover event file ${file}:`,
err
);
}
}
}
} catch (error) {
logger.error("Failed to scan for leftover event files");
}
}
public async add( public async add(
orgId: string, orgId: string,
featureId: FeatureId, featureId: FeatureId,
@@ -130,9 +206,7 @@ export class UsageService {
} }
// Log event for Stripe // Log event for Stripe
// if (privateConfig.getRawPrivateConfig().flags.usage_reporting) { await this.logStripeEvent(featureId, value, customerId);
// await this.logStripeEvent(featureId, value, customerId);
// }
return usage || null; return usage || null;
} catch (error: any) { } catch (error: any) {
@@ -212,7 +286,7 @@ export class UsageService {
return new Date(date * 1000).toISOString().split("T")[0]; return new Date(date * 1000).toISOString().split("T")[0];
} }
async updateCount( async updateDaily(
orgId: string, orgId: string,
featureId: FeatureId, featureId: FeatureId,
value?: number, value?: number,
@@ -238,6 +312,8 @@ export class UsageService {
value = this.truncateValue(value); value = this.truncateValue(value);
} }
const today = this.getTodayDateString();
let currentUsage: Usage | null = null; let currentUsage: Usage | null = null;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
@@ -251,34 +327,66 @@ export class UsageService {
.limit(1); .limit(1);
if (currentUsage) { if (currentUsage) {
await trx const lastUpdateDate = this.getDateString(
.update(usage) currentUsage.updatedAt
.set({ );
instantaneousValue: value, const currentRunningTotal = currentUsage.latestValue;
updatedAt: Math.floor(Date.now() / 1000) const lastDailyValue = currentUsage.instantaneousValue || 0;
})
.where(eq(usage.usageId, usageId)); if (value == undefined || value === null) {
value = currentUsage.instantaneousValue || 0;
}
if (lastUpdateDate === today) {
// Same day update: replace the daily value
// Remove old daily value from running total, add new value
const newRunningTotal = this.truncateValue(
currentRunningTotal - lastDailyValue + value
);
await trx
.update(usage)
.set({
latestValue: newRunningTotal,
instantaneousValue: value,
updatedAt: Math.floor(Date.now() / 1000)
})
.where(eq(usage.usageId, usageId));
} else {
// New day: add to running total
const newRunningTotal = this.truncateValue(
currentRunningTotal + value
);
await trx
.update(usage)
.set({
latestValue: newRunningTotal,
instantaneousValue: value,
updatedAt: Math.floor(Date.now() / 1000)
})
.where(eq(usage.usageId, usageId));
}
} else { } else {
// First record for this meter // First record for this meter
const meterId = getFeatureMeterId(featureId); const meterId = getFeatureMeterId(featureId);
const truncatedValue = this.truncateValue(value || 0);
await trx.insert(usage).values({ await trx.insert(usage).values({
usageId, usageId,
featureId, featureId,
orgId, orgId,
meterId, meterId,
instantaneousValue: value || 0, instantaneousValue: truncatedValue,
latestValue: value || 0, latestValue: truncatedValue,
updatedAt: Math.floor(Date.now() / 1000) updatedAt: Math.floor(Date.now() / 1000)
}); });
} }
}); });
// if (privateConfig.getRawPrivateConfig().flags.usage_reporting) { await this.logStripeEvent(featureId, value || 0, customerId);
// await this.logStripeEvent(featureId, value || 0, customerId);
// }
} catch (error) { } catch (error) {
logger.error( logger.error(
`Failed to update count usage for ${orgId}/${featureId}:`, `Failed to update daily usage for ${orgId}/${featureId}:`,
error error
); );
} }
@@ -342,58 +450,121 @@ export class UsageService {
} }
}; };
this.addEventToMemory(event); await this.writeEventToFile(event);
await this.checkAndUploadEvents(); await this.checkAndUploadFile();
} }
private addEventToMemory(event: StripeEvent): void { private async writeEventToFile(event: StripeEvent): Promise<void> {
if (!this.bucketName) { if (!this.eventsDir || !this.bucketName) {
logger.warn( logger.warn(
"S3 bucket name is not configured, skipping event storage." "Stripe local file path or bucket name is not configured, skipping event file write."
); );
return; return;
} }
this.events.push(event); if (!this.currentEventFile) {
} this.currentEventFile = this.generateEventFileName();
this.currentFileStartTime = Date.now();
private async checkAndUploadEvents(): Promise<void> {
const now = Date.now();
const timeSinceLastUpload = now - this.lastUploadTime;
// Check if at least 1 minute has passed since last upload
if (timeSinceLastUpload >= 60000 && this.events.length > 0) {
await this.uploadEventsToS3();
}
}
private async uploadEventsToS3(): Promise<void> {
if (!this.bucketName) {
logger.warn(
"S3 bucket name is not configured, skipping S3 upload."
);
return;
} }
if (this.events.length === 0) { const filePath = path.join(this.eventsDir, this.currentEventFile);
return;
}
// Check if already uploading
if (this.isUploading) {
logger.debug("Already uploading events, skipping");
return;
}
this.isUploading = true;
try { try {
// Take a snapshot of current events and clear the array let events: StripeEvent[] = [];
const eventsToUpload = [...this.events];
this.events = [];
this.lastUploadTime = Date.now();
const fileName = this.generateEventFileName(); // Try to read existing file
const fileContent = JSON.stringify(eventsToUpload, null, 2); try {
const fileContent = await fs.readFile(filePath, "utf-8");
events = JSON.parse(fileContent);
} catch (error) {
// File doesn't exist or is empty, start with empty array
events = [];
}
// Add new event
events.push(event);
// Write back to file
await fs.writeFile(filePath, JSON.stringify(events, null, 2));
} catch (error) {
logger.error("Failed to write event to file:", error);
}
}
private async checkAndUploadFile(): Promise<void> {
if (!this.currentEventFile) {
return;
}
const now = Date.now();
const fileAge = now - this.currentFileStartTime;
// Check if file is at least 1 minute old
if (fileAge >= 60000) {
// 60 seconds
await this.uploadFileToS3();
}
}
private async uploadFileToS3(): Promise<void> {
if (!this.bucketName || !this.eventsDir) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping S3 upload."
);
return;
}
if (!this.currentEventFile) {
return;
}
const fileName = this.currentEventFile;
const filePath = path.join(this.eventsDir, fileName);
// Check if this file is already being uploaded
if (this.uploadingFiles.has(fileName)) {
logger.debug(
`File ${fileName} is already being uploaded, skipping`
);
return;
}
// Mark file as being uploaded
this.uploadingFiles.add(fileName);
try {
// Check if file exists before trying to read it
try {
await fs.access(filePath);
} catch (error) {
logger.debug(
`File ${fileName} does not exist, may have been already processed`
);
this.uploadingFiles.delete(fileName);
// Reset current file if it was this file
if (this.currentEventFile === fileName) {
this.currentEventFile = null;
this.currentFileStartTime = 0;
}
return;
}
// Check if file exists and has content
const fileContent = await fs.readFile(filePath, "utf-8");
const events = JSON.parse(fileContent);
if (events.length === 0) {
// No events to upload, just clean up
try {
await fs.unlink(filePath);
} catch (unlinkError) {
// File may have been already deleted
logger.debug(
`File ${fileName} was already deleted during cleanup`
);
}
this.currentEventFile = null;
this.uploadingFiles.delete(fileName);
return;
}
// Upload to S3 // Upload to S3
const uploadCommand = new PutObjectCommand({ const uploadCommand = new PutObjectCommand({
@@ -405,15 +576,29 @@ export class UsageService {
await s3Client.send(uploadCommand); await s3Client.send(uploadCommand);
// Clean up local file - check if it still exists before unlinking
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
// File may have been already deleted by another process
logger.debug(
`File ${fileName} was already deleted during upload`
);
}
logger.info( logger.info(
`Uploaded ${fileName} to S3 with ${eventsToUpload.length} events` `Uploaded ${fileName} to S3 with ${events.length} events`
); );
// Reset for next file
this.currentEventFile = null;
this.currentFileStartTime = 0;
} catch (error) { } catch (error) {
logger.error("Failed to upload events to S3:", error); logger.error(`Failed to upload ${fileName} to S3:`, error);
// Note: Events are lost if upload fails. In a production system,
// you might want to add the events back to the array or implement retry logic
} finally { } finally {
this.isUploading = false; // Always remove from uploading set
this.uploadingFiles.delete(fileName);
} }
} }
@@ -498,16 +683,129 @@ export class UsageService {
} }
} }
public async getUsageDaily(
orgId: string,
featureId: FeatureId
): Promise<Usage | null> {
if (noop()) {
return null;
}
await this.updateDaily(orgId, featureId); // Ensure daily usage is updated
return this.getUsage(orgId, featureId);
}
public async forceUpload(): Promise<void> { public async forceUpload(): Promise<void> {
if (this.events.length > 0) { await this.uploadFileToS3();
// Force upload regardless of time }
this.lastUploadTime = 0; // Reset to force upload
await this.uploadEventsToS3(); /**
* Scan the events directory for files older than 1 minute and upload them if not empty.
*/
private async uploadOldEventFiles(): Promise<void> {
if (!this.eventsDir || !this.bucketName) {
logger.warn(
"Stripe local file path or bucket name is not configured, skipping old event file upload."
);
return;
}
try {
const files = await fs.readdir(this.eventsDir);
const now = Date.now();
for (const file of files) {
if (!file.endsWith(".json")) continue;
// Skip files that are already being uploaded
if (this.uploadingFiles.has(file)) {
logger.debug(
`Skipping file ${file} as it's already being uploaded`
);
continue;
}
const filePath = path.join(this.eventsDir, file);
try {
// Check if file still exists before processing
try {
await fs.access(filePath);
} catch (accessError) {
logger.debug(`File ${file} does not exist, skipping`);
continue;
}
const stat = await fs.stat(filePath);
const age = now - stat.mtimeMs;
if (age >= 90000) {
// 1.5 minutes - Mark as being uploaded
this.uploadingFiles.add(file);
try {
const fileContent = await fs.readFile(
filePath,
"utf-8"
);
const events = JSON.parse(fileContent);
if (Array.isArray(events) && events.length > 0) {
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: file,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
// Check if file still exists before unlinking
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`File ${file} was already deleted during interval upload`
);
}
logger.info(
`Interval: Uploaded event file ${file} to S3 with ${events.length} events`
);
// If this was the current event file, reset it
if (this.currentEventFile === file) {
this.currentEventFile = null;
this.currentFileStartTime = 0;
}
} else {
// Remove empty file
try {
await fs.access(filePath);
await fs.unlink(filePath);
} catch (unlinkError) {
logger.debug(
`Empty file ${file} was already deleted`
);
}
}
} finally {
// Always remove from uploading set
this.uploadingFiles.delete(file);
}
}
} catch (err) {
logger.error(
`Interval: Error processing event file ${file}:`,
err
);
// Remove from uploading set on error
this.uploadingFiles.delete(file);
}
}
} catch (err) {
logger.error("Interval: Failed to scan for event files:", err);
} }
} }
public async checkLimitSet( public async checkLimitSet(
orgId: string, orgId: string,
kickSites = false,
featureId?: FeatureId, featureId?: FeatureId,
usage?: Usage, usage?: Usage,
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
@@ -581,6 +879,58 @@ export class UsageService {
break; // Exit early if any limit is exceeded break; // Exit early if any limit is exceeded
} }
} }
// If any limits are exceeded, disconnect all sites for this organization
if (hasExceededLimits && kickSites) {
logger.warn(
`Disconnecting all sites for org ${orgId} due to exceeded limits`
);
// Get all sites for this organization
const orgSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, orgId));
// Mark all sites as offline and send termination messages
const siteUpdates = orgSites.map((site) => site.siteId);
if (siteUpdates.length > 0) {
// Send termination messages to newt sites
for (const site of orgSites) {
if (site.type === "newt") {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (newt) {
const payload = {
type: `newt/wg/terminate`,
data: {
reason: "Usage limits exceeded"
}
};
// Don't await to prevent blocking
await sendToClient(newt.newtId, payload).catch(
(error: any) => {
logger.error(
`Failed to send termination message to newt ${newt.newtId}:`,
error
);
}
);
}
}
}
logger.info(
`Disconnected ${orgSites.length} sites for org ${orgId} due to exceeded limits`
);
}
}
} catch (error) { } catch (error) {
logger.error(`Error checking limits for org ${orgId}:`, error); logger.error(`Error checking limits for org ${orgId}:`, error);
} }

View File

@@ -32,7 +32,7 @@ import { resourcePassword } from "@server/db";
import { hashPassword } from "@server/auth/password"; import { hashPassword } from "@server/auth/password";
import { isValidCIDR, isValidIP, isValidUrlGlobPattern } from "../validators"; import { isValidCIDR, isValidIP, isValidUrlGlobPattern } from "../validators";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "../billing/tierMatrix"; import { build } from "@server/build";
export type ProxyResourcesResults = { export type ProxyResourcesResults = {
proxyResource: Resource; proxyResource: Resource;
@@ -212,7 +212,7 @@ export async function updateProxyResources(
} else { } else {
// Update existing resource // Update existing resource
const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.maintencePage); const isLicensed = await isLicensedOrSubscribed(orgId);
if (!isLicensed) { if (!isLicensed) {
resourceData.maintenance = undefined; resourceData.maintenance = undefined;
} }
@@ -648,7 +648,7 @@ export async function updateProxyResources(
); );
} }
const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.maintencePage); const isLicensed = await isLicensedOrSubscribed(orgId);
if (!isLicensed) { if (!isLicensed) {
resourceData.maintenance = undefined; resourceData.maintenance = undefined;
} }

View File

@@ -20,7 +20,6 @@ import { sendTerminateClient } from "@server/routers/client/terminate";
import { and, eq, notInArray, type InferInsertModel } from "drizzle-orm"; import { and, eq, notInArray, type InferInsertModel } from "drizzle-orm";
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations"; import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
import { OlmErrorCodes } from "@server/routers/olm/error"; import { OlmErrorCodes } from "@server/routers/olm/error";
import { tierMatrix } from "./billing/tierMatrix";
export async function calculateUserClientsForOrgs( export async function calculateUserClientsForOrgs(
userId: string, userId: string,
@@ -190,8 +189,7 @@ export async function calculateUserClientsForOrgs(
const niceId = await getUniqueClientName(orgId); const niceId = await getUniqueClientName(orgId);
const isOrgLicensed = await isLicensedOrSubscribed( const isOrgLicensed = await isLicensedOrSubscribed(
userOrg.orgId, userOrg.orgId
tierMatrix.deviceApprovals
); );
const requireApproval = const requireApproval =
build !== "oss" && build !== "oss" &&

View File

@@ -107,11 +107,6 @@ export class Config {
process.env.MAXMIND_ASN_PATH = parsedConfig.server.maxmind_asn_path; process.env.MAXMIND_ASN_PATH = parsedConfig.server.maxmind_asn_path;
} }
process.env.DISABLE_ENTERPRISE_FEATURES = parsedConfig.flags
?.disable_enterprise_features
? "true"
: "false";
this.rawConfig = parsedConfig; this.rawConfig = parsedConfig;
} }

View File

@@ -182,7 +182,7 @@ export async function createUserAccountOrg(
const customerId = await createCustomer(orgId, userEmail); const customerId = await createCustomer(orgId, userEmail);
if (customerId) { if (customerId) {
await usageService.updateCount(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org await usageService.updateDaily(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org
} }
return { return {

View File

@@ -1,8 +1,3 @@
import { Tier } from "@server/types/Tiers"; export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
export async function isLicensedOrSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
return false; return false;
} }

View File

@@ -1,8 +0,0 @@
import { Tier } from "@server/types/Tiers";
export async function isSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
return false;
}

View File

@@ -331,8 +331,7 @@ export const configSchema = z
disable_local_sites: z.boolean().optional(), disable_local_sites: z.boolean().optional(),
disable_basic_wireguard_sites: z.boolean().optional(), disable_basic_wireguard_sites: z.boolean().optional(),
disable_config_managed_domains: z.boolean().optional(), disable_config_managed_domains: z.boolean().optional(),
disable_product_help_banners: z.boolean().optional(), disable_product_help_banners: z.boolean().optional()
disable_enterprise_features: z.boolean().optional()
}) })
.optional(), .optional(),
dns: z dns: z

View File

@@ -29,4 +29,3 @@ export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess"; export * from "./verifySiteResourceAccess";
export * from "./logActionAudit"; export * from "./logActionAudit";
export * from "./verifyOlmAccess"; export * from "./verifyOlmAccess";
export * from "./verifyLimits";

View File

@@ -4,6 +4,7 @@ import { apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
export async function verifyApiKeyOrgAccess( export async function verifyApiKeyOrgAccess(
req: Request, req: Request,

View File

@@ -1,43 +0,0 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
export async function verifyLimits(
req: Request,
res: Response,
next: NextFunction
) {
if (build != "saas") {
return next();
}
const orgId = req.userOrgId || req.apiKeyOrg?.orgId || req.params.orgId;
if (!orgId) {
return next(); // its fine if we silently fail here because this is not critical to operation or security and its better user experience if we dont fail
}
try {
const reject = await usageService.checkLimitSet(orgId);
if (reject) {
return next(
createHttpError(
HttpCode.PAYMENT_REQUIRED,
"Organization has exceeded its usage limits. Please upgrade your plan or contact support."
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking limits"
)
);
}
}

View File

@@ -11,59 +11,46 @@
* This file is not licensed under the AGPLv3. * This file is not licensed under the AGPLv3.
*/ */
import { getTierPriceSet } from "@server/lib/billing/tiers";
import { getOrgSubscriptionsData } from "@server/private/routers/billing/getOrgSubscriptions";
import { build } from "@server/build"; import { build } from "@server/build";
import { db, customers, subscriptions } from "@server/db";
import { Tier } from "@server/types/Tiers";
import { eq, and, ne } from "drizzle-orm";
export async function getOrgTierData( export async function getOrgTierData(
orgId: string orgId: string
): Promise<{ tier: Tier | null; active: boolean }> { ): Promise<{ tier: string | null; active: boolean }> {
let tier: Tier | null = null; let tier = null;
let active = false; let active = false;
if (build !== "saas") { if (build !== "saas") {
return { tier, active }; return { tier, active };
} }
try { // TODO: THIS IS INEFFICIENT!!! WE SHOULD IMPROVE HOW WE STORE TIERS WITH SUBSCRIPTIONS AND RETRIEVE THEM
// Get customer for org
const [customer] = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.limit(1);
if (customer) { const subscriptionsWithItems = await getOrgSubscriptionsData(orgId);
// Query for active subscriptions that are not license type
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
ne(subscriptions.type, "license")
)
)
.limit(1);
if (subscription) { for (const { subscription, items } of subscriptionsWithItems) {
// Validate that subscription.type is one of the expected tier values if (items && items.length > 0) {
if ( const tierPriceSet = getTierPriceSet();
subscription.type === "tier1" || // Iterate through tiers in order (earlier keys are higher tiers)
subscription.type === "tier2" || for (const [tierId, priceId] of Object.entries(tierPriceSet)) {
subscription.type === "tier3" // Check if any subscription item matches this tier's price ID
) { const matchingItem = items.find((item) => item.priceId === priceId);
tier = subscription.type; if (matchingItem) {
active = true; tier = tierId;
break;
} }
} }
} }
} catch (error) {
// If org not found or error occurs, return null tier and inactive
// This is acceptable behavior as per the function signature
}
if (subscription && subscription.status === "active") {
active = true;
}
// If we found a tier and active subscription, we can stop
if (tier && active) {
break;
}
}
return { tier, active }; return { tier, active };
} }

View File

@@ -13,6 +13,8 @@
import { build } from "@server/build"; import { build } from "@server/build";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db"; import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license"; import license from "#private/license/license";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { import {
@@ -78,8 +80,6 @@ export async function checkOrgAccessPolicy(
} }
} }
// TODO: check that the org is subscribed
// get the needed data // get the needed data
if (!props.org) { if (!props.org) {

View File

@@ -125,6 +125,16 @@ export class PrivateConfig {
this.rawPrivateConfig.server.reo_client_id; this.rawPrivateConfig.server.reo_client_id;
} }
if (this.rawPrivateConfig.stripe?.s3Bucket) {
process.env.S3_BUCKET = this.rawPrivateConfig.stripe.s3Bucket;
}
if (this.rawPrivateConfig.stripe?.localFilePath) {
process.env.LOCAL_FILE_PATH =
this.rawPrivateConfig.stripe.localFilePath;
}
if (this.rawPrivateConfig.stripe?.s3Region) {
process.env.S3_REGION = this.rawPrivateConfig.stripe.s3Region;
}
if (this.rawPrivateConfig.flags.use_pangolin_dns) { if (this.rawPrivateConfig.flags.use_pangolin_dns) {
process.env.USE_PANGOLIN_DNS = process.env.USE_PANGOLIN_DNS =
this.rawPrivateConfig.flags.use_pangolin_dns.toString(); this.rawPrivateConfig.flags.use_pangolin_dns.toString();

View File

@@ -13,20 +13,18 @@
import { build } from "@server/build"; import { build } from "@server/build";
import license from "#private/license/license"; import license from "#private/license/license";
import { isSubscribed } from "#private/lib/isSubscribed"; import { getOrgTierData } from "#private/lib/billing";
import { Tier } from "@server/types/Tiers"; import { TierId } from "@server/lib/billing/tiers";
export async function isLicensedOrSubscribed( export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
orgId: string,
tiers: Tier[]
): Promise<boolean> {
if (build === "enterprise") { if (build === "enterprise") {
return await license.isUnlocked(); return await license.isUnlocked();
} }
if (build === "saas") { if (build === "saas") {
return isSubscribed(orgId, tiers); const { tier } = await getOrgTierData(orgId);
return tier === TierId.STANDARD;
} }
return false; return false;
} }

View File

@@ -1,29 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { Tier } from "@server/types/Tiers";
export async function isSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const isTier = (tier && tiers.includes(tier)) || false;
return active && isTier;
}
return false;
}

View File

@@ -95,7 +95,7 @@ export const privateConfigSchema = z.object({
.object({ .object({
enable_redis: z.boolean().optional().default(false), enable_redis: z.boolean().optional().default(false),
use_pangolin_dns: z.boolean().optional().default(false), use_pangolin_dns: z.boolean().optional().default(false),
use_org_only_idp: z.boolean().optional().default(false), use_org_only_idp: z.boolean().optional().default(false)
}) })
.optional() .optional()
.prefault({}), .prefault({}),
@@ -176,9 +176,9 @@ export const privateConfigSchema = z.object({
.string() .string()
.optional() .optional()
.transform(getEnvOrYaml("STRIPE_WEBHOOK_SECRET")), .transform(getEnvOrYaml("STRIPE_WEBHOOK_SECRET")),
// s3Bucket: z.string(), s3Bucket: z.string(),
// s3Region: z.string().default("us-east-1"), s3Region: z.string().default("us-east-1"),
// localFilePath: z.string().optional() localFilePath: z.string()
}) })
.optional() .optional()
}); });

View File

@@ -16,61 +16,46 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing"; import { getOrgTierData } from "#private/lib/billing";
import { Tier } from "@server/types/Tiers";
export function verifyValidSubscription(tiers: Tier[]) {
return async function (
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
if (build != "saas") {
return next();
}
const orgId =
req.params.orgId ||
req.body.orgId ||
req.query.orgId ||
req.userOrgId;
if (!orgId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization ID is required to verify subscription"
)
);
}
const { tier, active } = await getOrgTierData(orgId);
const isTier = tiers.includes(tier as Tier);
if (!active) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization does not have an active subscription"
)
);
}
if (!isTier) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization subscription tier does not have access to this feature"
)
);
}
export async function verifyValidSubscription(
req: Request,
res: Response,
next: NextFunction
) {
try {
if (build != "saas") {
return next(); return next();
} catch (e) { }
const orgId = req.params.orgId || req.body.orgId || req.query.orgId || req.userOrgId;
if (!orgId) {
return next( return next(
createHttpError( createHttpError(
HttpCode.INTERNAL_SERVER_ERROR, HttpCode.BAD_REQUEST,
"Error verifying subscription" "Organization ID is required to verify subscription"
) )
); );
} }
};
const tier = await getOrgTierData(orgId);
if (!tier.active) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization does not have an active subscription"
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying subscription"
)
);
}
} }

View File

@@ -19,6 +19,8 @@ import { fromError } from "zod-validation-error";
import type { Request, Response, NextFunction } from "express"; import type { Request, Response, NextFunction } from "express";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { import {
approvals, approvals,
clients, clients,
@@ -219,6 +221,19 @@ export async function listApprovals(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const approvalsList = await queryApprovals( const approvalsList = await queryApprovals(
orgId.toString(), orgId.toString(),
limit, limit,

View File

@@ -17,7 +17,10 @@ import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { build } from "@server/build";
import { approvals, clients, db, orgs, type Approval } from "@server/db"; import { approvals, clients, db, orgs, type Approval } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import response from "@server/lib/response"; import response from "@server/lib/response";
import { and, eq, type InferInsertModel } from "drizzle-orm"; import { and, eq, type InferInsertModel } from "drizzle-orm";
import type { NextFunction, Request, Response } from "express"; import type { NextFunction, Request, Response } from "express";
@@ -61,6 +64,20 @@ export async function processPendingApproval(
} }
const { orgId, approvalId } = parsedParams.data; const { orgId, approvalId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const updateData = parsedBody.data; const updateData = parsedBody.data;
const approval = await db const approval = await db

View File

@@ -13,3 +13,4 @@
export * from "./transferSession"; export * from "./transferSession";
export * from "./getSessionTransferToken"; export * from "./getSessionTransferToken";
export * from "./quickStart";

View File

@@ -0,0 +1,585 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { NextFunction, Request, Response } from "express";
import {
account,
db,
domainNamespaces,
domains,
exitNodes,
newts,
newtSessions,
orgs,
passwordResetTokens,
Resource,
resourcePassword,
resourcePincode,
resources,
resourceWhitelist,
roleResources,
roles,
roleSites,
sites,
targetHealthCheck,
targets,
userResources,
userSites
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { users } from "@server/db";
import { fromError } from "zod-validation-error";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import { eq, and, sql } from "drizzle-orm";
import moment from "moment";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import logger from "@server/logger";
import { hashPassword } from "@server/auth/password";
import { UserType } from "@server/types/UserTypes";
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
import { sendEmail } from "@server/emails";
import WelcomeQuickStart from "@server/emails/templates/WelcomeQuickStart";
import { alphabet, generateRandomString } from "oslo/crypto";
import { createDate, TimeSpan } from "oslo";
import { getUniqueResourceName, getUniqueSiteName } from "@server/db/names";
import { pickPort } from "@server/routers/target/helpers";
import { addTargets } from "@server/routers/newt/targets";
import { isTargetValid } from "@server/lib/validators";
import { listExitNodes } from "#private/lib/exitNodes";
const bodySchema = z.object({
email: z.email().toLowerCase(),
ip: z.string().refine(isTargetValid),
method: z.enum(["http", "https"]),
port: z.int().min(1).max(65535),
pincode: z
.string()
.regex(/^\d{6}$/)
.optional(),
password: z.string().min(4).max(100).optional(),
enableWhitelist: z.boolean().optional().default(true),
animalId: z.string() // This is actually the secret key for the backend
});
export type QuickStartBody = z.infer<typeof bodySchema>;
export type QuickStartResponse = {
newtId: string;
newtSecret: string;
resourceUrl: string;
completeSignUpLink: string;
};
const DEMO_UBO_KEY = "b460293f-347c-4b30-837d-4e06a04d5a22";
export async function quickStart(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const {
email,
ip,
method,
port,
pincode,
password,
enableWhitelist,
animalId
} = parsedBody.data;
try {
const tokenValidation = validateTokenOnApi(animalId);
if (!tokenValidation.isValid) {
logger.warn(
`Quick start failed for ${email} token ${animalId}: ${tokenValidation.message}`
);
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid or expired token"
)
);
}
if (animalId === DEMO_UBO_KEY) {
if (email !== "mehrdad@getubo.com") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid email for demo Ubo key"
)
);
}
const [existing] = await db
.select()
.from(users)
.where(
and(
eq(users.email, email),
eq(users.type, UserType.Internal)
)
);
if (existing) {
// delete the user if it already exists
await db.delete(users).where(eq(users.userId, existing.userId));
const orgId = `org_${existing.userId}`;
await db.delete(orgs).where(eq(orgs.orgId, orgId));
}
}
const tempPassword = generateId(15);
const passwordHash = await hashPassword(tempPassword);
const userId = generateId(15);
// TODO: see if that user already exists?
// Create the sandbox user
const existing = await db
.select()
.from(users)
.where(
and(eq(users.email, email), eq(users.type, UserType.Internal))
);
if (existing && existing.length > 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A user with that email address already exists"
)
);
}
let newtId: string;
let secret: string;
let fullDomain: string;
let resource: Resource;
let completeSignUpLink: string;
await db.transaction(async (trx) => {
await trx.insert(users).values({
userId: userId,
type: UserType.Internal,
username: email,
email: email,
passwordHash,
dateCreated: moment().toISOString()
});
// create user"s account
await trx.insert(account).values({
userId
});
});
const { success, error, org } = await createUserAccountOrg(
userId,
email
);
if (!success) {
if (error) {
throw new Error(error);
}
throw new Error("Failed to create user account and organization");
}
if (!org) {
throw new Error("Failed to create user account and organization");
}
const orgId = org.orgId;
await db.transaction(async (trx) => {
const token = generateRandomString(
8,
alphabet("0-9", "A-Z", "a-z")
);
await trx
.delete(passwordResetTokens)
.where(eq(passwordResetTokens.userId, userId));
const tokenHash = await hashPassword(token);
await trx.insert(passwordResetTokens).values({
userId: userId,
email: email,
tokenHash,
expiresAt: createDate(new TimeSpan(7, "d")).getTime()
});
// // Create the sandbox newt
// const newClientAddress = await getNextAvailableClientSubnet(orgId);
// if (!newClientAddress) {
// throw new Error("No available subnet found");
// }
// const clientAddress = newClientAddress.split("/")[0];
newtId = generateId(15);
secret = generateId(48);
// Create the sandbox site
const siteNiceId = await getUniqueSiteName(orgId);
const siteName = `First Site`;
// pick a random exit node
const exitNodesList = await listExitNodes(orgId);
// select a random exit node
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
if (!randomExitNode) {
throw new Error("No exit nodes available");
}
const [newSite] = await trx
.insert(sites)
.values({
orgId,
exitNodeId: randomExitNode.exitNodeId,
name: siteName,
niceId: siteNiceId,
// address: clientAddress,
type: "newt",
dockerSocketEnabled: true
})
.returning();
const siteId = newSite.siteId;
const adminRole = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
throw new Error("Admin role not found");
}
await trx.insert(roleSites).values({
roleId: adminRole[0].roleId,
siteId: newSite.siteId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
await trx.insert(userSites).values({
userId: req.user?.userId!,
siteId: newSite.siteId
});
}
// add the peer to the exit node
const secretHash = await hashPassword(secret!);
await trx.insert(newts).values({
newtId: newtId!,
secretHash,
siteId: newSite.siteId,
dateCreated: moment().toISOString()
});
const [randomNamespace] = await trx
.select()
.from(domainNamespaces)
.orderBy(sql`RANDOM()`)
.limit(1);
if (!randomNamespace) {
throw new Error("No domain namespace available");
}
const [randomNamespaceDomain] = await trx
.select()
.from(domains)
.where(eq(domains.domainId, randomNamespace.domainId))
.limit(1);
if (!randomNamespaceDomain) {
throw new Error("No domain found for the namespace");
}
const resourceNiceId = await getUniqueResourceName(orgId);
// Create sandbox resource
const subdomain = `${resourceNiceId}-${generateId(5)}`;
fullDomain = `${subdomain}.${randomNamespaceDomain.baseDomain}`;
const resourceName = `First Resource`;
const newResource = await trx
.insert(resources)
.values({
niceId: resourceNiceId,
fullDomain,
domainId: randomNamespaceDomain.domainId,
orgId,
name: resourceName,
subdomain,
http: true,
protocol: "tcp",
ssl: true,
sso: false,
emailWhitelistEnabled: enableWhitelist
})
.returning();
await trx.insert(roleResources).values({
roleId: adminRole[0].roleId,
resourceId: newResource[0].resourceId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the resource
await trx.insert(userResources).values({
userId: req.user?.userId!,
resourceId: newResource[0].resourceId
});
}
resource = newResource[0];
// Create the sandbox target
const { internalPort, targetIps } = await pickPort(siteId!, trx);
if (!internalPort) {
throw new Error("No available internal port");
}
const newTarget = await trx
.insert(targets)
.values({
resourceId: resource.resourceId,
siteId: siteId!,
internalPort,
ip,
method,
port,
enabled: true
})
.returning();
const newHealthcheck = await trx
.insert(targetHealthCheck)
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {
await trx
.delete(resourcePincode)
.where(
eq(resourcePincode.resourceId, resource!.resourceId)
);
const pincodeHash = await hashPassword(pincode);
await trx.insert(resourcePincode).values({
resourceId: resource!.resourceId,
pincodeHash,
digitLength: 6
});
}
// Set resource password if provided
if (password) {
await trx
.delete(resourcePassword)
.where(
eq(resourcePassword.resourceId, resource!.resourceId)
);
const passwordHash = await hashPassword(password);
await trx.insert(resourcePassword).values({
resourceId: resource!.resourceId,
passwordHash
});
}
// Set resource OTP if whitelist is enabled
if (enableWhitelist) {
await trx.insert(resourceWhitelist).values({
email,
resourceId: resource!.resourceId
});
}
completeSignUpLink = `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}&token=${token}`;
// Store token for email outside transaction
await sendEmail(
WelcomeQuickStart({
username: email,
link: completeSignUpLink,
fallbackLink: `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}`,
resourceMethod: method,
resourceHostname: ip,
resourcePort: port,
resourceUrl: `https://${fullDomain}`,
cliCommand: `newt --id ${newtId} --secret ${secret}`
}),
{
to: email,
from: config.getNoReplyEmail(),
subject: `Access your Pangolin dashboard and resources`
}
);
});
return response<QuickStartResponse>(res, {
data: {
newtId: newtId!,
newtSecret: secret!,
resourceUrl: `https://${fullDomain!}`,
completeSignUpLink: completeSignUpLink!
},
success: true,
error: false,
message: "Quick start completed successfully",
status: HttpCode.OK
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Account already exists with that email. Email: ${email}. IP: ${req.ip}.`
);
}
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A user with that email address already exists"
)
);
} else {
logger.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to do quick start"
)
);
}
}
}
const BACKEND_SECRET_KEY = "4f9b6000-5d1a-11f0-9de7-ff2cc032f501";
/**
* Validates a token received from the frontend.
* @param {string} token The validation token from the request.
* @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid.
*/
const validateTokenOnApi = (
token: string
): { isValid: boolean; message: string } => {
if (token === DEMO_UBO_KEY) {
// Special case for demo UBO key
return { isValid: true, message: "Demo UBO key is valid." };
}
if (!token) {
return { isValid: false, message: "Error: No token provided." };
}
try {
// 1. Decode the base64 string
const decodedB64 = atob(token);
// 2. Reverse the character code manipulation
const deobfuscated = decodedB64
.split("")
.map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift
.join("");
// 3. Split the data to get the original secret and timestamp
const parts = deobfuscated.split("|");
if (parts.length !== 2) {
throw new Error("Invalid token format.");
}
const receivedKey = parts[0];
const tokenTimestamp = parseInt(parts[1], 10);
// 4. Check if the secret key matches
if (receivedKey !== BACKEND_SECRET_KEY) {
return { isValid: false, message: "Invalid token: Key mismatch." };
}
// 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks
const now = Date.now();
const timeDifference = now - tokenTimestamp;
if (timeDifference > 30000) {
// 30 seconds
return { isValid: false, message: "Invalid token: Expired." };
}
if (timeDifference < 0) {
// Timestamp is in the future
return {
isValid: false,
message: "Invalid token: Timestamp is in the future."
};
}
// If all checks pass, the token is valid
return { isValid: true, message: "Token is valid!" };
} catch (error) {
// This will catch errors from atob (if not valid base64) or other issues.
return {
isValid: false,
message: `Error: ${(error as Error).message}`
};
}
};

View File

@@ -1,268 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { customers, db, subscriptions, subscriptionItems } from "@server/db";
import { eq, and, or } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
import {
getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet,
getStarterFeaturePriceSet,
FeatureId,
type FeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
const changeTierSchema = z.strictObject({
orgId: z.string()
});
const changeTierBodySchema = z.strictObject({
tier: z.enum(["tier1", "tier2", "tier3"])
});
export async function changeTier(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = changeTierSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const parsedBody = changeTierBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { tier } = parsedBody.data;
// Get the customer for this org
const [customer] = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.limit(1);
if (!customer) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No customer found for this organization"
)
);
}
// Get the active subscription for this customer
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
or(
eq(subscriptions.type, "tier1"),
eq(subscriptions.type, "tier2"),
eq(subscriptions.type, "tier3")
)
)
)
.limit(1);
if (!subscription) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No active subscription found for this organization"
)
);
}
// Get the target tier's price set
let targetPriceSet: FeaturePriceSet;
if (tier === "tier1") {
targetPriceSet = getHomeLabFeaturePriceSet();
} else if (tier === "tier2") {
targetPriceSet = getStarterFeaturePriceSet();
} else if (tier === "tier3") {
targetPriceSet = getScaleFeaturePriceSet();
} else {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid tier"));
}
// Get current subscription items from our database
const currentItems = await db
.select()
.from(subscriptionItems)
.where(
eq(
subscriptionItems.subscriptionId,
subscription.subscriptionId
)
);
if (currentItems.length === 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No subscription items found"
)
);
}
// Retrieve the full subscription from Stripe to get item IDs
const stripeSubscription = await stripe!.subscriptions.retrieve(
subscription.subscriptionId
);
// Determine if we're switching between different products
// tier1 uses TIER1 product, tier2/tier3 use USERS product
const currentTier = subscription.type;
const switchingProducts =
(currentTier === "tier1" &&
(tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") &&
tier === "tier1");
let updatedSubscription;
if (switchingProducts) {
// When switching between different products, we need to:
// 1. Delete old subscription items
// 2. Add new subscription items
logger.info(
`Switching products from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}`
);
// Build array to delete all existing items and add new ones
const itemsToUpdate: any[] = [];
// Mark all existing items for deletion
for (const stripeItem of stripeSubscription.items.data) {
itemsToUpdate.push({
id: stripeItem.id,
deleted: true
});
}
// Add new items for the target tier
const newLineItems = await getLineItems(targetPriceSet, orgId);
for (const lineItem of newLineItems) {
itemsToUpdate.push(lineItem);
}
updatedSubscription = await stripe!.subscriptions.update(
subscription.subscriptionId,
{
items: itemsToUpdate,
proration_behavior: "create_prorations"
}
);
} else {
// Same product, different price tier (tier2 <-> tier3)
// We can simply update the price
logger.info(
`Updating price from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}`
);
const itemsToUpdate = stripeSubscription.items.data.map(
(stripeItem) => {
// Find the corresponding item in our database
const dbItem = currentItems.find(
(item) => item.priceId === stripeItem.price.id
);
if (!dbItem) {
// Keep the existing item unchanged if we can't find it
return {
id: stripeItem.id,
price: stripeItem.price.id,
quantity: stripeItem.quantity
};
}
// Map to the corresponding feature in the new tier
const newPriceId = targetPriceSet[FeatureId.USERS];
if (newPriceId) {
return {
id: stripeItem.id,
price: newPriceId,
quantity: stripeItem.quantity
};
}
// If no mapping found, keep existing
return {
id: stripeItem.id,
price: stripeItem.price.id,
quantity: stripeItem.quantity
};
}
);
updatedSubscription = await stripe!.subscriptions.update(
subscription.subscriptionId,
{
items: itemsToUpdate,
proration_behavior: "create_prorations"
}
);
}
logger.info(
`Successfully changed tier to ${tier} for org ${orgId}, subscription ${subscription.subscriptionId}`
);
return response<{ subscriptionId: string; newTier: string }>(res, {
data: {
subscriptionId: updatedSubscription.id,
newTier: tier
},
success: true,
error: false,
message: "Tier change successful",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error changing tier:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred while changing tier"
)
);
}
}

View File

@@ -22,23 +22,14 @@ import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe"; import stripe from "#private/lib/stripe";
import { import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
getHomeLabFeaturePriceSet, import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
getScaleFeaturePriceSet,
getStarterFeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
import Stripe from "stripe";
const createCheckoutSessionSchema = z.strictObject({ const createCheckoutSessionSchema = z.strictObject({
orgId: z.string() orgId: z.string()
}); });
const createCheckoutSessionBodySchema = z.strictObject({ export async function createCheckoutSessionSAAS(
tier: z.enum(["tier1", "tier2", "tier3"])
});
export async function createCheckoutSession(
req: Request, req: Request,
res: Response, res: Response,
next: NextFunction next: NextFunction
@@ -56,18 +47,6 @@ export async function createCheckoutSession(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const parsedBody = createCheckoutSessionBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { tier } = parsedBody.data;
// check if we already have a customer for this org // check if we already have a customer for this org
const [customer] = await db const [customer] = await db
.select() .select()
@@ -86,26 +65,20 @@ export async function createCheckoutSession(
); );
} }
let lineItems: Stripe.Checkout.SessionCreateParams.LineItem[]; const standardTierPrice = getTierPriceSet()[TierId.STANDARD];
if (tier === "tier1") {
lineItems = await getLineItems(getHomeLabFeaturePriceSet(), orgId);
} else if (tier === "tier2") {
lineItems = await getLineItems(getStarterFeaturePriceSet(), orgId);
} else if (tier === "tier3") {
lineItems = await getLineItems(getScaleFeaturePriceSet(), orgId);
} else {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan"));
}
logger.debug(`Line items: ${JSON.stringify(lineItems)}`);
const session = await stripe!.checkout.sessions.create({ const session = await stripe!.checkout.sessions.create({
client_reference_id: orgId, // So we can look it up the org later on the webhook client_reference_id: orgId, // So we can look it up the org later on the webhook
billing_address_collection: "required", billing_address_collection: "required",
line_items: lineItems, line_items: [
{
price: standardTierPrice, // Use the standard tier
quantity: 1
},
...getLineItems(getStandardFeaturePriceSet())
], // Start with the standard feature set that matches the free limits
customer: customer.customerId, customer: customer.customerId,
mode: "subscription", mode: "subscription",
allow_promotion_codes: true,
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?success=true&session_id={CHECKOUT_SESSION_ID}`, success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?success=true&session_id={CHECKOUT_SESSION_ID}`,
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?canceled=true` cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?canceled=true`
}); });

View File

@@ -1,297 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { SubscriptionType } from "./hooks/getSubType";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { Tier } from "@server/types/Tiers";
import logger from "@server/logger";
import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db";
import { eq } from "drizzle-orm";
export async function handleTierChange(
orgId: string,
newTier: SubscriptionType | null,
previousTier?: SubscriptionType | null
): Promise<void> {
logger.info(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// License subscriptions are handled separately and don't use the tier matrix
if (newTier === "license") {
logger.debug(
`New tier is license for org ${orgId}, no feature lifecycle handling needed`
);
return;
}
// If newTier is null, treat as free tier - disable all features
if (newTier === null) {
logger.info(
`Org ${orgId} is reverting to free tier, disabling all paid features`
);
// Disable all features in the tier matrix
for (const [featureKey] of Object.entries(tierMatrix)) {
const feature = featureKey as TierFeature;
logger.info(
`Feature ${feature} is not available in free tier for org ${orgId}. Disabling...`
);
await disableFeature(orgId, feature);
}
logger.info(
`Completed free tier feature lifecycle handling for org ${orgId}`
);
return;
}
// Get the tier (cast as Tier since we've ruled out "license" and null)
const tier = newTier as Tier;
// Check each feature in the tier matrix
for (const [featureKey, allowedTiers] of Object.entries(tierMatrix)) {
const feature = featureKey as TierFeature;
const isFeatureAvailable = allowedTiers.includes(tier);
if (!isFeatureAvailable) {
logger.info(
`Feature ${feature} is not available in tier ${tier} for org ${orgId}. Disabling...`
);
await disableFeature(orgId, feature);
} else {
logger.debug(
`Feature ${feature} is available in tier ${tier} for org ${orgId}`
);
}
}
logger.info(
`Completed tier change feature lifecycle handling for org ${orgId}`
);
}
async function disableFeature(
orgId: string,
feature: TierFeature
): Promise<void> {
try {
switch (feature) {
case TierFeature.OrgOidc:
await disableOrgOidc(orgId);
break;
case TierFeature.LoginPageDomain:
await disableLoginPageDomain(orgId);
break;
case TierFeature.DeviceApprovals:
await disableDeviceApprovals(orgId);
break;
case TierFeature.LoginPageBranding:
await disableLoginPageBranding(orgId);
break;
case TierFeature.LogExport:
await disableLogExport(orgId);
break;
case TierFeature.AccessLogs:
await disableAccessLogs(orgId);
break;
case TierFeature.ActionLogs:
await disableActionLogs(orgId);
break;
case TierFeature.RotateCredentials:
await disableRotateCredentials(orgId);
break;
case TierFeature.MaintencePage:
await disableMaintencePage(orgId);
break;
case TierFeature.DevicePosture:
await disableDevicePosture(orgId);
break;
case TierFeature.TwoFactorEnforcement:
await disableTwoFactorEnforcement(orgId);
break;
case TierFeature.SessionDurationPolicies:
await disableSessionDurationPolicies(orgId);
break;
case TierFeature.PasswordExpirationPolicies:
await disablePasswordExpirationPolicies(orgId);
break;
case TierFeature.AutoProvisioning:
await disableAutoProvisioning(orgId);
break;
default:
logger.warn(
`Unknown feature ${feature} for org ${orgId}, skipping`
);
}
logger.info(
`Successfully disabled feature ${feature} for org ${orgId}`
);
} catch (error) {
logger.error(
`Error disabling feature ${feature} for org ${orgId}:`,
error
);
throw error;
}
}
async function disableOrgOidc(orgId: string): Promise<void> {}
async function disableDeviceApprovals(orgId: string): Promise<void> {
await db
.update(roles)
.set({ requireDeviceApproval: false })
.where(eq(roles.orgId, orgId));
logger.info(`Disabled device approvals on all roles for org ${orgId}`);
}
async function disableLoginPageBranding(orgId: string): Promise<void> {
const [existingBranding] = await db
.select()
.from(loginPageBrandingOrg)
.where(eq(loginPageBrandingOrg.orgId, orgId));
if (existingBranding) {
await db
.delete(loginPageBranding)
.where(
eq(
loginPageBranding.loginPageBrandingId,
existingBranding.loginPageBrandingId
)
);
logger.info(`Disabled login page branding for org ${orgId}`);
}
}
async function disableLoginPageDomain(orgId: string): Promise<void> {
const [existingLoginPage] = await db
.select()
.from(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId))
.innerJoin(
loginPage,
eq(loginPage.loginPageId, loginPageOrg.loginPageId)
);
if (existingLoginPage) {
await db
.delete(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId));
await db
.delete(loginPage)
.where(
eq(
loginPage.loginPageId,
existingLoginPage.loginPageOrg.loginPageId
)
);
logger.info(`Disabled login page domain for org ${orgId}`);
}
}
async function disableLogExport(orgId: string): Promise<void> {}
async function disableAccessLogs(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ settingsLogRetentionDaysAccess: 0 })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled access logs for org ${orgId}`);
}
async function disableActionLogs(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ settingsLogRetentionDaysAction: 0 })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled action logs for org ${orgId}`);
}
async function disableRotateCredentials(orgId: string): Promise<void> {}
async function disableMaintencePage(orgId: string): Promise<void> {
await db
.update(resources)
.set({
maintenanceModeEnabled: false
})
.where(eq(resources.orgId, orgId));
logger.info(`Disabled maintenance page on all resources for org ${orgId}`);
}
async function disableDevicePosture(orgId: string): Promise<void> {}
async function disableTwoFactorEnforcement(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ requireTwoFactor: false })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled two-factor enforcement for org ${orgId}`);
}
async function disableSessionDurationPolicies(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ maxSessionLengthHours: null })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled session duration policies for org ${orgId}`);
}
async function disablePasswordExpirationPolicies(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ passwordExpiryDays: null })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled password expiration policies for org ${orgId}`);
}
async function disableAutoProvisioning(orgId: string): Promise<void> {
// Get all IDP IDs for this org through the idpOrg join table
const orgIdps = await db
.select({ idpId: idpOrg.idpId })
.from(idpOrg)
.where(eq(idpOrg.orgId, orgId));
// Update autoProvision to false for all IDPs in this org
for (const { idpId } of orgIdps) {
await db
.update(idp)
.set({ autoProvision: false })
.where(eq(idp.idpId, idpId));
}
}

View File

@@ -23,8 +23,6 @@ import logger from "@server/logger";
import { fromZodError } from "zod-validation-error"; import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { GetOrgSubscriptionResponse } from "@server/routers/billing/types"; import { GetOrgSubscriptionResponse } from "@server/routers/billing/types";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
// Import tables for billing // Import tables for billing
import { import {
@@ -72,19 +70,9 @@ export async function getOrgSubscriptions(
throw err; throw err;
} }
let limitsExceeded = false;
if (build === "saas") {
try {
limitsExceeded = await usageService.checkLimitSet(orgId);
} catch (err) {
logger.error("Error checking limits for org %s: %s", orgId, err);
}
}
return response<GetOrgSubscriptionResponse>(res, { return response<GetOrgSubscriptionResponse>(res, {
data: { data: {
subscriptions, subscriptions
...(build === "saas" ? { limitsExceeded } : {})
}, },
success: true, success: true,
error: false, error: false,

View File

@@ -78,10 +78,16 @@ export async function getOrgUsage(
// Get usage for org // Get usage for org
const usageData = []; const usageData = [];
const sites = await usageService.getUsage(orgId, FeatureId.SITES); const siteUptime = await usageService.getUsage(
const users = await usageService.getUsage(orgId, FeatureId.USERS); orgId,
const domains = await usageService.getUsage(orgId, FeatureId.DOMAINS); FeatureId.SITE_UPTIME
const remoteExitNodes = await usageService.getUsage( );
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
const domains = await usageService.getUsageDaily(
orgId,
FeatureId.DOMAINS
);
const remoteExitNodes = await usageService.getUsageDaily(
orgId, orgId,
FeatureId.REMOTE_EXIT_NODES FeatureId.REMOTE_EXIT_NODES
); );
@@ -90,8 +96,8 @@ export async function getOrgUsage(
FeatureId.EGRESS_DATA_MB FeatureId.EGRESS_DATA_MB
); );
if (sites) { if (siteUptime) {
usageData.push(sites); usageData.push(siteUptime);
} }
if (users) { if (users) {
usageData.push(users); usageData.push(users);

View File

@@ -1,62 +1,35 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { import {
getLicensePriceSet, getLicensePriceSet,
} from "@server/lib/billing/licenses"; } from "@server/lib/billing/licenses";
import { import {
getHomeLabFeaturePriceSet, getTierPriceSet,
getStarterFeaturePriceSet, } from "@server/lib/billing/tiers";
getScaleFeaturePriceSet,
} from "@server/lib/billing/features";
import Stripe from "stripe"; import Stripe from "stripe";
import { Tier } from "@server/types/Tiers";
export type SubscriptionType = Tier | "license"; export function getSubType(fullSubscription: Stripe.Response<Stripe.Subscription>): "saas" | "license" {
export function getSubType(fullSubscription: Stripe.Response<Stripe.Subscription>): SubscriptionType | null {
// Determine subscription type by checking subscription items // Determine subscription type by checking subscription items
if (!Array.isArray(fullSubscription.items?.data) || fullSubscription.items.data.length === 0) { let type: "saas" | "license" = "saas";
return null; if (Array.isArray(fullSubscription.items?.data)) {
} for (const item of fullSubscription.items.data) {
const priceId = item.price.id;
for (const item of fullSubscription.items.data) { // Check if price ID matches any license price
const priceId = item.price.id; const licensePrices = Object.values(getLicensePriceSet());
// Check if price ID matches any license price if (licensePrices.includes(priceId)) {
const licensePrices = Object.values(getLicensePriceSet()); type = "license";
if (licensePrices.includes(priceId)) { break;
return "license"; }
}
// Check if price ID matches home lab tier // Check if price ID matches any tier price (saas)
const homeLabPrices = Object.values(getHomeLabFeaturePriceSet()); const tierPrices = Object.values(getTierPriceSet());
if (homeLabPrices.includes(priceId)) {
return "tier1";
}
// Check if price ID matches tier2 tier if (tierPrices.includes(priceId)) {
const tier2Prices = Object.values(getStarterFeaturePriceSet()); type = "saas";
if (tier2Prices.includes(priceId)) { break;
return "tier2"; }
}
// Check if price ID matches tier3 tier
const tier3Prices = Object.values(getScaleFeaturePriceSet());
if (tier3Prices.includes(priceId)) {
return "tier3";
} }
} }
return null; return type;
} }

View File

@@ -31,8 +31,6 @@ import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
import { sendEmail } from "@server/emails"; import { sendEmail } from "@server/emails";
import EnterpriseEditionKeyGenerated from "@server/emails/templates/EnterpriseEditionKeyGenerated"; import EnterpriseEditionKeyGenerated from "@server/emails/templates/EnterpriseEditionKeyGenerated";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { getFeatureIdByPriceId } from "@server/lib/billing/features";
import { handleTierChange } from "../featureLifecycle";
export async function handleSubscriptionCreated( export async function handleSubscriptionCreated(
subscription: Stripe.Subscription subscription: Stripe.Subscription
@@ -61,8 +59,6 @@ export async function handleSubscriptionCreated(
return; return;
} }
const type = getSubType(fullSubscription);
const newSubscription = { const newSubscription = {
subscriptionId: subscription.id, subscriptionId: subscription.id,
customerId: subscription.customer as string, customerId: subscription.customer as string,
@@ -70,9 +66,7 @@ export async function handleSubscriptionCreated(
canceledAt: subscription.canceled_at canceledAt: subscription.canceled_at
? subscription.canceled_at ? subscription.canceled_at
: null, : null,
createdAt: subscription.created, createdAt: subscription.created
type: type,
version: 1 // we are hardcoding the initial version when the subscription is created, and then we will increment it on every update
}; };
await db.insert(subscriptions).values(newSubscription); await db.insert(subscriptions).values(newSubscription);
@@ -93,15 +87,10 @@ export async function handleSubscriptionCreated(
name = product.name || null; name = product.name || null;
} }
// Get the feature ID from the price ID
const featureId = getFeatureIdByPriceId(item.price.id);
return { return {
stripeSubscriptionItemId: item.id,
subscriptionId: subscription.id, subscriptionId: subscription.id,
planId: item.plan.id, planId: item.plan.id,
priceId: item.price.id, priceId: item.price.id,
featureId: featureId || null,
meterId: item.plan.meter, meterId: item.plan.meter,
unitAmount: item.price.unit_amount || 0, unitAmount: item.price.unit_amount || 0,
currentPeriodStart: item.current_period_start, currentPeriodStart: item.current_period_start,
@@ -140,23 +129,17 @@ export async function handleSubscriptionCreated(
return; return;
} }
if (type === "tier1" || type === "tier2" || type === "tier3") { const type = getSubType(fullSubscription);
if (type === "saas") {
logger.debug( logger.debug(
`Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}` `Handling SAAS subscription lifecycle for org ${customer.orgId}`
); );
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses // we only need to handle the limit lifecycle for saas subscriptions not for the licenses
await handleSubscriptionLifesycle( await handleSubscriptionLifesycle(
customer.orgId, customer.orgId,
subscription.status, subscription.status
type
); );
// Handle initial tier setup - disable features not available in this tier
logger.info(
`Setting up initial tier features for org ${customer.orgId} with type ${type}`
);
await handleTierChange(customer.orgId, type);
const [orgUserRes] = await db const [orgUserRes] = await db
.select() .select()
.from(userOrgs) .from(userOrgs)

View File

@@ -27,7 +27,6 @@ import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType"; import { getSubType } from "./getSubType";
import stripe from "#private/lib/stripe"; import stripe from "#private/lib/stripe";
import privateConfig from "#private/lib/config"; import privateConfig from "#private/lib/config";
import { handleTierChange } from "../featureLifecycle";
export async function handleSubscriptionDeleted( export async function handleSubscriptionDeleted(
subscription: Stripe.Subscription subscription: Stripe.Subscription
@@ -77,23 +76,16 @@ export async function handleSubscriptionDeleted(
} }
const type = getSubType(fullSubscription); const type = getSubType(fullSubscription);
if (type == "tier1" || type == "tier2" || type == "tier3") { if (type === "saas") {
logger.debug( logger.debug(
`Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}` `Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}`
); );
await handleSubscriptionLifesycle( await handleSubscriptionLifesycle(
customer.orgId, customer.orgId,
subscription.status, subscription.status
type
); );
// Handle feature lifecycle for cancellation - disable all tier-specific features
logger.info(
`Disabling tier-specific features for org ${customer.orgId} due to subscription deletion`
);
await handleTierChange(customer.orgId, null, type);
const [orgUserRes] = await db const [orgUserRes] = await db
.select() .select()
.from(userOrgs) .from(userOrgs)

View File

@@ -23,12 +23,11 @@ import {
} from "@server/db"; } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features"; import { getFeatureIdByMetricId } from "@server/lib/billing/features";
import stripe from "#private/lib/stripe"; import stripe from "#private/lib/stripe";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle"; import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { getSubType, SubscriptionType } from "./getSubType"; import { getSubType } from "./getSubType";
import privateConfig from "#private/lib/config"; import privateConfig from "#private/lib/config";
import { handleTierChange } from "../featureLifecycle";
export async function handleSubscriptionUpdated( export async function handleSubscriptionUpdated(
subscription: Stripe.Subscription, subscription: Stripe.Subscription,
@@ -65,9 +64,6 @@ export async function handleSubscriptionUpdated(
.where(eq(customers.customerId, subscription.customer as string)) .where(eq(customers.customerId, subscription.customer as string))
.limit(1); .limit(1);
const type = getSubType(fullSubscription);
const previousType = existingSubscription.type as SubscriptionType | null;
await db await db
.update(subscriptions) .update(subscriptions)
.set({ .set({
@@ -76,55 +72,25 @@ export async function handleSubscriptionUpdated(
? subscription.canceled_at ? subscription.canceled_at
: null, : null,
updatedAt: Math.floor(Date.now() / 1000), updatedAt: Math.floor(Date.now() / 1000),
billingCycleAnchor: subscription.billing_cycle_anchor, billingCycleAnchor: subscription.billing_cycle_anchor
type: type
}) })
.where(eq(subscriptions.subscriptionId, subscription.id)); .where(eq(subscriptions.subscriptionId, subscription.id));
// Handle tier change if the subscription type changed
if (type && type !== previousType) {
logger.info(
`Tier change detected for org ${customer.orgId}: ${previousType} -> ${type}`
);
await handleTierChange(customer.orgId, type, previousType ?? undefined);
}
// Upsert subscription items // Upsert subscription items
if (Array.isArray(fullSubscription.items?.data)) { if (Array.isArray(fullSubscription.items?.data)) {
// First, get existing items to preserve featureId when there's no match const itemsToUpsert = fullSubscription.items.data.map((item) => ({
const existingItems = await db subscriptionId: subscription.id,
.select() planId: item.plan.id,
.from(subscriptionItems) priceId: item.price.id,
.where(eq(subscriptionItems.subscriptionId, subscription.id)); meterId: item.plan.meter,
unitAmount: item.price.unit_amount || 0,
const itemsToUpsert = fullSubscription.items.data.map((item) => { currentPeriodStart: item.current_period_start,
// Try to get featureId from price currentPeriodEnd: item.current_period_end,
let featureId: string | null = getFeatureIdByPriceId(item.price.id) || null; tiers: item.price.tiers
? JSON.stringify(item.price.tiers)
// If no match, try to preserve existing featureId : null,
if (!featureId) { interval: item.plan.interval
const existingItem = existingItems.find( }));
(ei) => ei.stripeSubscriptionItemId === item.id
);
featureId = existingItem?.featureId || null;
}
return {
stripeSubscriptionItemId: item.id,
subscriptionId: subscription.id,
planId: item.plan.id,
priceId: item.price.id,
featureId: featureId,
meterId: item.plan.meter,
unitAmount: item.price.unit_amount || 0,
currentPeriodStart: item.current_period_start,
currentPeriodEnd: item.current_period_end,
tiers: item.price.tiers
? JSON.stringify(item.price.tiers)
: null,
interval: item.plan.interval
};
});
if (itemsToUpsert.length > 0) { if (itemsToUpsert.length > 0) {
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
@@ -188,7 +154,7 @@ export async function handleSubscriptionUpdated(
const orgId = customer.orgId; const orgId = customer.orgId;
if (!orgId) { if (!orgId) {
logger.debug( logger.warn(
`No orgId found in subscription metadata for subscription ${subscription.id}. Skipping usage reset.` `No orgId found in subscription metadata for subscription ${subscription.id}. Skipping usage reset.`
); );
continue; continue;
@@ -268,29 +234,17 @@ export async function handleSubscriptionUpdated(
} }
// --- end usage update --- // --- end usage update ---
if (type === "tier1" || type === "tier2" || type === "tier3") { const type = getSubType(fullSubscription);
if (type === "saas") {
logger.debug( logger.debug(
`Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}` `Handling SAAS subscription lifecycle for org ${customer.orgId}`
); );
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses // we only need to handle the limit lifecycle for saas subscriptions not for the licenses
await handleSubscriptionLifesycle( await handleSubscriptionLifesycle(
customer.orgId, customer.orgId,
subscription.status, subscription.status
type
); );
} else {
// Handle feature lifecycle when subscription is canceled or becomes unpaid
if (
subscription.status === "canceled" ||
subscription.status === "unpaid" ||
subscription.status === "incomplete_expired"
) {
logger.info(
`Subscription ${subscription.id} for org ${customer.orgId} is ${subscription.status}, disabling paid features`
);
await handleTierChange(customer.orgId, null, previousType ?? undefined);
}
} else if (type === "license") {
if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") { if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") {
try { try {
// WARNING: // WARNING:

View File

@@ -11,9 +11,8 @@
* This file is not licensed under the AGPLv3. * This file is not licensed under the AGPLv3.
*/ */
export * from "./createCheckoutSession"; export * from "./createCheckoutSessionSAAS";
export * from "./createPortalSession"; export * from "./createPortalSession";
export * from "./getOrgSubscriptions"; export * from "./getOrgSubscriptions";
export * from "./getOrgUsage"; export * from "./getOrgUsage";
export * from "./internalGetOrgTier"; export * from "./internalGetOrgTier";
export * from "./changeTier";

View File

@@ -13,66 +13,38 @@
import { import {
freeLimitSet, freeLimitSet,
tier1LimitSet,
tier2LimitSet,
tier3LimitSet,
limitsService, limitsService,
LimitSet subscribedLimitSet
} from "@server/lib/billing"; } from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { SubscriptionType } from "./hooks/getSubType"; import logger from "@server/logger";
function getLimitSetForSubscriptionType(
subType: SubscriptionType | null
): LimitSet {
switch (subType) {
case "tier1":
return tier1LimitSet;
case "tier2":
return tier2LimitSet;
case "tier3":
return tier3LimitSet;
case "license":
// License subscriptions use tier2 limits by default
// This can be adjusted based on your business logic
return tier2LimitSet;
default:
return freeLimitSet;
}
}
export async function handleSubscriptionLifesycle( export async function handleSubscriptionLifesycle(
orgId: string, orgId: string,
status: string, status: string
subType: SubscriptionType | null
) { ) {
switch (status) { switch (status) {
case "active": case "active":
const activeLimitSet = getLimitSetForSubscriptionType(subType); await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
await limitsService.applyLimitSetToOrg(orgId, activeLimitSet); await usageService.checkLimitSet(orgId, true);
await usageService.checkLimitSet(orgId);
break; break;
case "canceled": case "canceled":
// Subscription canceled - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId); await usageService.checkLimitSet(orgId, true);
break; break;
case "past_due": case "past_due":
// Payment past due - keep current limits but notify customer // Optionally handle past due status, e.g., notify customer
// Limits will revert to free tier if it becomes unpaid
break; break;
case "unpaid": case "unpaid":
// Subscription unpaid - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId); await usageService.checkLimitSet(orgId, true);
break; break;
case "incomplete": case "incomplete":
// Payment incomplete - give them time to complete payment // Optionally handle incomplete status, e.g., notify customer
break; break;
case "incomplete_expired": case "incomplete_expired":
// Payment never completed - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId); await usageService.checkLimitSet(orgId, true);
break; break;
default: default:
break; break;

View File

@@ -31,8 +31,7 @@ import {
verifyUserHasAction, verifyUserHasAction,
verifyUserIsServerAdmin, verifyUserIsServerAdmin,
verifySiteAccess, verifySiteAccess,
verifyClientAccess, verifyClientAccess
verifyLimits
} from "@server/middlewares"; } from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions"; import { ActionsEnum } from "@server/auth/actions";
import { import {
@@ -53,7 +52,6 @@ import {
authenticated as a, authenticated as a,
authRouter as aa authRouter as aa
} from "@server/routers/external"; } from "@server/routers/external";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export const authenticated = a; export const authenticated = a;
export const unauthenticated = ua; export const unauthenticated = ua;
@@ -78,9 +76,7 @@ unauthenticated.post(
authenticated.put( authenticated.put(
"/org/:orgId/idp/oidc", "/org/:orgId/idp/oidc",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createIdp), verifyUserHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp), logActionAudit(ActionsEnum.createIdp),
orgIdp.createOrgOidcIdp orgIdp.createOrgOidcIdp
@@ -89,10 +85,8 @@ authenticated.put(
authenticated.post( authenticated.post(
"/org/:orgId/idp/:idpId/oidc", "/org/:orgId/idp/:idpId/oidc",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyOrgAccess, verifyOrgAccess,
verifyIdpAccess, verifyIdpAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateIdp), verifyUserHasAction(ActionsEnum.updateIdp),
logActionAudit(ActionsEnum.updateIdp), logActionAudit(ActionsEnum.updateIdp),
orgIdp.updateOrgOidcIdp orgIdp.updateOrgOidcIdp
@@ -141,27 +135,35 @@ authenticated.post(
verifyValidLicense, verifyValidLicense,
verifyOrgAccess, verifyOrgAccess,
verifyCertificateAccess, verifyCertificateAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.restartCertificate), verifyUserHasAction(ActionsEnum.restartCertificate),
logActionAudit(ActionsEnum.restartCertificate), logActionAudit(ActionsEnum.restartCertificate),
certificates.restartCertificate certificates.restartCertificate
); );
if (build === "saas") { if (build === "saas") {
authenticated.post( unauthenticated.post(
"/org/:orgId/billing/create-checkout-session", "/quick-start",
verifyOrgAccess, rateLimit({
verifyUserHasAction(ActionsEnum.billing), windowMs: 15 * 60 * 1000,
logActionAudit(ActionsEnum.billing), max: 100,
billing.createCheckoutSession keyGenerator: (req) => req.path,
handler: (req, res, next) => {
const message = `We're too busy right now. Please try again later.`;
return next(
createHttpError(HttpCode.TOO_MANY_REQUESTS, message)
);
},
store: createStore()
}),
auth.quickStart
); );
authenticated.post( authenticated.post(
"/org/:orgId/billing/change-tier", "/org/:orgId/billing/create-checkout-session-saas",
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.billing), verifyUserHasAction(ActionsEnum.billing),
logActionAudit(ActionsEnum.billing), logActionAudit(ActionsEnum.billing),
billing.changeTier billing.createCheckoutSessionSAAS
); );
authenticated.post( authenticated.post(
@@ -241,7 +243,6 @@ authenticated.put(
"/org/:orgId/remote-exit-node", "/org/:orgId/remote-exit-node",
verifyValidLicense, verifyValidLicense,
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createRemoteExitNode), verifyUserHasAction(ActionsEnum.createRemoteExitNode),
logActionAudit(ActionsEnum.createRemoteExitNode), logActionAudit(ActionsEnum.createRemoteExitNode),
remoteExitNode.createRemoteExitNode remoteExitNode.createRemoteExitNode
@@ -285,9 +286,7 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/login-page", "/org/:orgId/login-page",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageDomain),
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createLoginPage), verifyUserHasAction(ActionsEnum.createLoginPage),
logActionAudit(ActionsEnum.createLoginPage), logActionAudit(ActionsEnum.createLoginPage),
loginPage.createLoginPage loginPage.createLoginPage
@@ -296,10 +295,8 @@ authenticated.put(
authenticated.post( authenticated.post(
"/org/:orgId/login-page/:loginPageId", "/org/:orgId/login-page/:loginPageId",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageDomain),
verifyOrgAccess, verifyOrgAccess,
verifyLoginPageAccess, verifyLoginPageAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateLoginPage), verifyUserHasAction(ActionsEnum.updateLoginPage),
logActionAudit(ActionsEnum.updateLoginPage), logActionAudit(ActionsEnum.updateLoginPage),
loginPage.updateLoginPage loginPage.updateLoginPage
@@ -326,7 +323,6 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/approvals", "/org/:orgId/approvals",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.deviceApprovals),
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listApprovals), verifyUserHasAction(ActionsEnum.listApprovals),
logActionAudit(ActionsEnum.listApprovals), logActionAudit(ActionsEnum.listApprovals),
@@ -343,9 +339,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/approvals/:approvalId", "/org/:orgId/approvals/:approvalId",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.deviceApprovals),
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateApprovals), verifyUserHasAction(ActionsEnum.updateApprovals),
logActionAudit(ActionsEnum.updateApprovals), logActionAudit(ActionsEnum.updateApprovals),
approval.processPendingApproval approval.processPendingApproval
@@ -354,7 +348,6 @@ authenticated.put(
authenticated.get( authenticated.get(
"/org/:orgId/login-page-branding", "/org/:orgId/login-page-branding",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageBranding),
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.getLoginPage), verifyUserHasAction(ActionsEnum.getLoginPage),
logActionAudit(ActionsEnum.getLoginPage), logActionAudit(ActionsEnum.getLoginPage),
@@ -364,9 +357,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/login-page-branding", "/org/:orgId/login-page-branding",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageBranding),
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateLoginPage), verifyUserHasAction(ActionsEnum.updateLoginPage),
logActionAudit(ActionsEnum.updateLoginPage), logActionAudit(ActionsEnum.updateLoginPage),
loginPage.upsertLoginPageBranding loginPage.upsertLoginPageBranding
@@ -442,7 +433,7 @@ authenticated.post(
authenticated.get( authenticated.get(
"/org/:orgId/logs/action", "/org/:orgId/logs/action",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.actionLogs), verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs), verifyUserHasAction(ActionsEnum.exportLogs),
logs.queryActionAuditLogs logs.queryActionAuditLogs
@@ -451,7 +442,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/logs/action/export", "/org/:orgId/logs/action/export",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.logExport), verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs), verifyUserHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs), logActionAudit(ActionsEnum.exportLogs),
@@ -461,7 +452,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/logs/access", "/org/:orgId/logs/access",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.accessLogs), verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs), verifyUserHasAction(ActionsEnum.exportLogs),
logs.queryAccessAuditLogs logs.queryAccessAuditLogs
@@ -470,7 +461,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/logs/access/export", "/org/:orgId/logs/access/export",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.logExport), verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs), verifyUserHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs), logActionAudit(ActionsEnum.exportLogs),
@@ -479,20 +470,18 @@ authenticated.get(
authenticated.post( authenticated.post(
"/re-key/:clientId/regenerate-client-secret", "/re-key/:clientId/regenerate-client-secret",
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifyClientAccess, // this is first to set the org id verifyClientAccess, // this is first to set the org id
verifyLimits, verifyValidLicense,
verifyValidSubscription,
verifyUserHasAction(ActionsEnum.reGenerateSecret), verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret reKey.reGenerateClientSecret
); );
authenticated.post( authenticated.post(
"/re-key/:siteId/regenerate-site-secret", "/re-key/:siteId/regenerate-site-secret",
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifySiteAccess, // this is first to set the org id verifySiteAccess, // this is first to set the org id
verifyLimits, verifyValidLicense,
verifyValidSubscription,
verifyUserHasAction(ActionsEnum.reGenerateSecret), verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret reKey.reGenerateSiteSecret
); );
@@ -500,9 +489,8 @@ authenticated.post(
authenticated.put( authenticated.put(
"/re-key/:orgId/regenerate-remote-exit-node-secret", "/re-key/:orgId/regenerate-remote-exit-node-secret",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials), verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret), verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateExitNodeSecret reKey.reGenerateExitNodeSecret
); );

View File

@@ -126,7 +126,6 @@ export async function generateNewEnterpriseLicense(
], // Start with the standard feature set that matches the free limits ], // Start with the standard feature set that matches the free limits
customer: customer.customerId, customer: customer.customerId,
mode: "subscription", mode: "subscription",
allow_promotion_codes: true,
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?success=true&session_id={CHECKOUT_SESSION_ID}`, success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?success=true&session_id={CHECKOUT_SESSION_ID}`,
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?canceled=true` cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?canceled=true`
}); });

View File

@@ -19,20 +19,21 @@ import {
verifyApiKeyHasAction, verifyApiKeyHasAction,
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess, verifyApiKeyIdpAccess
verifyLimits
} from "@server/middlewares"; } from "@server/middlewares";
import { import {
verifyValidSubscription, verifyValidSubscription,
verifyValidLicense verifyValidLicense
} from "#private/middlewares"; } from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions"; import { ActionsEnum } from "@server/auth/actions";
import { import {
unauthenticated as ua, unauthenticated as ua,
authenticated as a authenticated as a
} from "@server/routers/integration"; } from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares"; import { logActionAudit } from "#private/middlewares";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import config from "#private/lib/config";
import { build } from "@server/build";
export const unauthenticated = ua; export const unauthenticated = ua;
export const authenticated = a; export const authenticated = a;
@@ -56,7 +57,7 @@ authenticated.delete(
authenticated.get( authenticated.get(
"/org/:orgId/logs/action", "/org/:orgId/logs/action",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.actionLogs), verifyValidSubscription,
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs), verifyApiKeyHasAction(ActionsEnum.exportLogs),
logs.queryActionAuditLogs logs.queryActionAuditLogs
@@ -65,7 +66,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/logs/action/export", "/org/:orgId/logs/action/export",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.logExport), verifyValidSubscription,
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs), verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs), logActionAudit(ActionsEnum.exportLogs),
@@ -75,7 +76,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/logs/access", "/org/:orgId/logs/access",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.accessLogs), verifyValidSubscription,
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs), verifyApiKeyHasAction(ActionsEnum.exportLogs),
logs.queryAccessAuditLogs logs.queryAccessAuditLogs
@@ -84,7 +85,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/logs/access/export", "/org/:orgId/logs/access/export",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.logExport), verifyValidSubscription,
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs), verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs), logActionAudit(ActionsEnum.exportLogs),
@@ -94,9 +95,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/idp/oidc", "/org/:orgId/idp/oidc",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createIdp), verifyApiKeyHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp), logActionAudit(ActionsEnum.createIdp),
orgIdp.createOrgOidcIdp orgIdp.createOrgOidcIdp
@@ -105,10 +104,8 @@ authenticated.put(
authenticated.post( authenticated.post(
"/org/:orgId/idp/:idpId/oidc", "/org/:orgId/idp/:idpId/oidc",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess, verifyApiKeyIdpAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateIdp), verifyApiKeyHasAction(ActionsEnum.updateIdp),
logActionAudit(ActionsEnum.updateIdp), logActionAudit(ActionsEnum.updateIdp),
orgIdp.updateOrgOidcIdp orgIdp.updateOrgOidcIdp

View File

@@ -30,7 +30,9 @@ import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { createCertificate } from "#private/routers/certificates/createCertificate"; import { createCertificate } from "#private/routers/certificates/createCertificate";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import { CreateLoginPageResponse } from "@server/routers/loginPage/types"; import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
@@ -74,6 +76,19 @@ export async function createLoginPage(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existing] = await db const [existing] = await db
.select() .select()
.from(loginPageOrg) .from(loginPageOrg)

View File

@@ -25,7 +25,9 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@@ -51,6 +53,18 @@ export async function deleteLoginPageBranding(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPageBranding] = await db const [existingLoginPageBranding] = await db
.select() .select()

View File

@@ -25,7 +25,9 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -49,6 +51,19 @@ export async function getLoginPageBranding(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPageBranding] = await db const [existingLoginPageBranding] = await db
.select() .select()
.from(loginPageBranding) .from(loginPageBranding)

View File

@@ -23,7 +23,9 @@ import { eq, and } from "drizzle-orm";
import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { subdomainSchema } from "@server/lib/schemas"; import { subdomainSchema } from "@server/lib/schemas";
import { createCertificate } from "#private/routers/certificates/createCertificate"; import { createCertificate } from "#private/routers/certificates/createCertificate";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import { UpdateLoginPageResponse } from "@server/routers/loginPage/types"; import { UpdateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z const paramsSchema = z
@@ -85,6 +87,18 @@ export async function updateLoginPage(
const { loginPageId, orgId } = parsedParams.data; const { loginPageId, orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPage] = await db const [existingLoginPage] = await db
.select() .select()

View File

@@ -25,8 +25,10 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq, InferInsertModel } from "drizzle-orm"; import { eq, InferInsertModel } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build"; import { build } from "@server/build";
import config from "#private/lib/config"; import config from "@server/private/lib/config";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -126,6 +128,19 @@ export async function upsertLoginPageBranding(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
let updateData = parsedBody.data satisfies InferInsertModel< let updateData = parsedBody.data satisfies InferInsertModel<
typeof loginPageBranding typeof loginPageBranding
>; >;

View File

@@ -24,9 +24,10 @@ import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl"; import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import { encrypt } from "@server/lib/crypto"; import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types"; import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() }); const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
@@ -102,19 +103,23 @@ export async function createOrgOidcIdp(
emailPath, emailPath,
namePath, namePath,
name, name,
autoProvision,
variant, variant,
roleMapping, roleMapping,
tags tags
} = parsedBody.data; } = parsedBody.data;
let { autoProvision } = parsedBody.data; if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const subscribed = await isSubscribed( const subscribed = tier === TierId.STANDARD;
orgId, if (!subscribed) {
tierMatrix.deviceApprovals return next(
); createHttpError(
if (!subscribed) { HttpCode.FORBIDDEN,
autoProvision = false; "This organization's current plan does not support this feature."
)
);
}
} }
const key = config.getRawConfig().server.secret!; const key = config.getRawConfig().server.secret!;

View File

@@ -24,8 +24,9 @@ import { idp, idpOidcConfig } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { encrypt } from "@server/lib/crypto"; import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { build } from "@server/build";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@@ -108,18 +109,22 @@ export async function updateOrgOidcIdp(
emailPath, emailPath,
namePath, namePath,
name, name,
autoProvision,
roleMapping, roleMapping,
tags tags
} = parsedBody.data; } = parsedBody.data;
let { autoProvision } = parsedBody.data; if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const subscribed = await isSubscribed( const subscribed = tier === TierId.STANDARD;
orgId, if (!subscribed) {
tierMatrix.deviceApprovals return next(
); createHttpError(
if (!subscribed) { HttpCode.FORBIDDEN,
autoProvision = false; "This organization's current plan does not support this feature."
)
);
}
} }
// Check if IDP exists and is of type OIDC // Check if IDP exists and is of type OIDC

View File

@@ -85,7 +85,7 @@ export async function createRemoteExitNode(
if (usage) { if (usage) {
const rejectRemoteExitNodes = await usageService.checkLimitSet( const rejectRemoteExitNodes = await usageService.checkLimitSet(
orgId, orgId,
false,
FeatureId.REMOTE_EXIT_NODES, FeatureId.REMOTE_EXIT_NODES,
{ {
...usage, ...usage,
@@ -97,7 +97,7 @@ export async function createRemoteExitNode(
return next( return next(
createHttpError( createHttpError(
HttpCode.FORBIDDEN, HttpCode.FORBIDDEN,
"Remote node limit exceeded. Please upgrade your plan." "Remote exit node limit exceeded. Please upgrade your plan or contact us at support@pangolin.net"
) )
); );
} }
@@ -224,7 +224,7 @@ export async function createRemoteExitNode(
}); });
if (numExitNodeOrgs) { if (numExitNodeOrgs) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.REMOTE_EXIT_NODES, FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length numExitNodeOrgs.length

View File

@@ -106,7 +106,7 @@ export async function deleteRemoteExitNode(
}); });
if (numExitNodeOrgs) { if (numExitNodeOrgs) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.REMOTE_EXIT_NODES, FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length numExitNodeOrgs.length

View File

@@ -1,6 +1,6 @@
import { db, orgs, requestAuditLog } from "@server/db"; import { db, orgs, requestAuditLog } from "@server/db";
import logger from "@server/logger"; import logger from "@server/logger";
import { and, eq, lt, sql } from "drizzle-orm"; import { and, eq, lt } from "drizzle-orm";
import cache from "@server/lib/cache"; import cache from "@server/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs"; import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
import { stripPortFromHost } from "@server/lib/ip"; import { stripPortFromHost } from "@server/lib/ip";
@@ -67,27 +67,17 @@ async function flushAuditLogs() {
const logsToWrite = auditLogBuffer.splice(0, auditLogBuffer.length); const logsToWrite = auditLogBuffer.splice(0, auditLogBuffer.length);
try { try {
// Use a transaction to ensure all inserts succeed or fail together // Batch insert logs in groups of 25 to avoid overwhelming the database
// This prevents index corruption from partial writes const BATCH_DB_SIZE = 25;
await db.transaction(async (tx) => { for (let i = 0; i < logsToWrite.length; i += BATCH_DB_SIZE) {
// Batch insert logs in groups of 25 to avoid overwhelming the database const batch = logsToWrite.slice(i, i + BATCH_DB_SIZE);
const BATCH_DB_SIZE = 25; await db.insert(requestAuditLog).values(batch);
for (let i = 0; i < logsToWrite.length; i += BATCH_DB_SIZE) { }
const batch = logsToWrite.slice(i, i + BATCH_DB_SIZE);
await tx.insert(requestAuditLog).values(batch);
}
});
logger.debug(`Flushed ${logsToWrite.length} audit logs to database`); logger.debug(`Flushed ${logsToWrite.length} audit logs to database`);
} catch (error) { } catch (error) {
logger.error("Error flushing audit logs:", error); logger.error("Error flushing audit logs:", error);
// On transaction error, put logs back at the front of the buffer to retry // On error, we lose these logs - consider a fallback strategy if needed
// but only if buffer isn't too large // (e.g., write to file, or put back in buffer with retry limit)
if (auditLogBuffer.length < MAX_BUFFER_SIZE - logsToWrite.length) {
auditLogBuffer.unshift(...logsToWrite);
logger.info(`Re-queued ${logsToWrite.length} audit logs for retry`);
} else {
logger.error(`Buffer full, dropped ${logsToWrite.length} audit logs`);
}
} finally { } finally {
isFlushInProgress = false; isFlushInProgress = false;
// If buffer filled up while we were flushing, flush again // If buffer filled up while we were flushing, flush again

View File

@@ -17,7 +17,8 @@ import {
ResourceHeaderAuthExtendedCompatibility, ResourceHeaderAuthExtendedCompatibility,
ResourcePassword, ResourcePassword,
ResourcePincode, ResourcePincode,
ResourceRule ResourceRule,
resourceSessions
} from "@server/db"; } from "@server/db";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { isIpInCidr, stripPortFromHost } from "@server/lib/ip"; import { isIpInCidr, stripPortFromHost } from "@server/lib/ip";
@@ -31,6 +32,7 @@ import { fromError } from "zod-validation-error";
import { getCountryCodeForIp } from "@server/lib/geoip"; import { getCountryCodeForIp } from "@server/lib/geoip";
import { getAsnForIp } from "@server/lib/asn"; import { getAsnForIp } from "@server/lib/asn";
import { getOrgTierData } from "#dynamic/lib/billing"; import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { verifyPassword } from "@server/auth/password"; import { verifyPassword } from "@server/auth/password";
import { import {
checkOrgAccessPolicy, checkOrgAccessPolicy,
@@ -38,9 +40,8 @@ import {
} from "#dynamic/lib/checkOrgAccessPolicy"; } from "#dynamic/lib/checkOrgAccessPolicy";
import { logRequestAudit } from "./logRequestAudit"; import { logRequestAudit } from "./logRequestAudit";
import cache from "@server/lib/cache"; import cache from "@server/lib/cache";
import semver from "semver";
import { APP_VERSION } from "@server/lib/consts"; import { APP_VERSION } from "@server/lib/consts";
import { isSubscribed } from "#private/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const verifyResourceSessionSchema = z.object({ const verifyResourceSessionSchema = z.object({
sessions: z.record(z.string(), z.string()).optional(), sessions: z.record(z.string(), z.string()).optional(),
@@ -797,11 +798,8 @@ async function notAllowed(
) { ) {
let loginPage: LoginPage | null = null; let loginPage: LoginPage | null = null;
if (orgId) { if (orgId) {
const subscribed = await isSubscribed( const { tier } = await getOrgTierData(orgId); // returns null in oss
orgId, if (tier === TierId.STANDARD) {
tierMatrix.loginPageDomain
);
if (subscribed) {
loginPage = await getOrgLoginPage(orgId); loginPage = await getOrgLoginPage(orgId);
} }
} }
@@ -854,8 +852,8 @@ async function headerAuthChallenged(
) { ) {
let loginPage: LoginPage | null = null; let loginPage: LoginPage | null = null;
if (orgId) { if (orgId) {
const subscribed = await isSubscribed(orgId, tierMatrix.loginPageDomain); const { tier } = await getOrgTierData(orgId); // returns null in oss
if (subscribed) { if (tier === TierId.STANDARD) {
loginPage = await getOrgLoginPage(orgId); loginPage = await getOrgLoginPage(orgId);
} }
} }
@@ -1041,11 +1039,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
const MAX_RECURSION_DEPTH = 100; const MAX_RECURSION_DEPTH = 100;
// Recursive function to try different wildcard matches // Recursive function to try different wildcard matches
function matchSegments( function matchSegments(patternIndex: number, pathIndex: number, depth: number = 0): boolean {
patternIndex: number,
pathIndex: number,
depth: number = 0
): boolean {
// Check recursion depth limit // Check recursion depth limit
if (depth > MAX_RECURSION_DEPTH) { if (depth > MAX_RECURSION_DEPTH) {
logger.warn( logger.warn(
@@ -1131,11 +1125,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
logger.debug( logger.debug(
`${indent}Segment with wildcard matches: "${currentPatternPart}" matches "${currentPathPart}"` `${indent}Segment with wildcard matches: "${currentPatternPart}" matches "${currentPathPart}"`
); );
return matchSegments( return matchSegments(patternIndex + 1, pathIndex + 1, depth + 1);
patternIndex + 1,
pathIndex + 1,
depth + 1
);
} }
logger.debug( logger.debug(

View File

@@ -2,8 +2,6 @@ import { Limit, Subscription, SubscriptionItem, Usage } from "@server/db";
export type GetOrgSubscriptionResponse = { export type GetOrgSubscriptionResponse = {
subscriptions: Array<{ subscription: Subscription; items: SubscriptionItem[] }>; subscriptions: Array<{ subscription: Subscription; items: SubscriptionItem[] }>;
/** When build === saas, true if org has exceeded plan limits (sites, users, etc.) */
limitsExceeded?: boolean;
}; };
export type GetOrgUsageResponse = { export type GetOrgUsageResponse = {

View File

@@ -101,7 +101,7 @@ export async function createClient(
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid IP." "Invalid subnet format. Please provide a valid CIDR notation."
) )
); );
} }

View File

@@ -13,7 +13,6 @@ import { OpenAPITags, registry } from "@server/openApi";
import { getUserDeviceName } from "@server/db/names"; import { getUserDeviceName } from "@server/db/names";
import { build } from "@server/build"; import { build } from "@server/build";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const getClientSchema = z.strictObject({ const getClientSchema = z.strictObject({
clientId: z clientId: z
@@ -57,29 +56,19 @@ async function query(clientId?: number, niceId?: string, orgId?: string) {
} }
type PostureData = { type PostureData = {
biometricsEnabled?: boolean | null | "-"; biometricsEnabled?: boolean | null;
diskEncrypted?: boolean | null | "-"; diskEncrypted?: boolean | null;
firewallEnabled?: boolean | null | "-"; firewallEnabled?: boolean | null;
autoUpdatesEnabled?: boolean | null | "-"; autoUpdatesEnabled?: boolean | null;
tpmAvailable?: boolean | null | "-"; tpmAvailable?: boolean | null;
windowsAntivirusEnabled?: boolean | null | "-"; windowsAntivirusEnabled?: boolean | null;
macosSipEnabled?: boolean | null | "-"; macosSipEnabled?: boolean | null;
macosGatekeeperEnabled?: boolean | null | "-"; macosGatekeeperEnabled?: boolean | null;
macosFirewallStealthMode?: boolean | null | "-"; macosFirewallStealthMode?: boolean | null;
linuxAppArmorEnabled?: boolean | null | "-"; linuxAppArmorEnabled?: boolean | null;
linuxSELinuxEnabled?: boolean | null | "-"; linuxSELinuxEnabled?: boolean | null;
}; };
function maskPostureDataWithPlaceholder(posture: PostureData): PostureData {
const masked: PostureData = {};
for (const key of Object.keys(posture) as (keyof PostureData)[]) {
if (posture[key] !== undefined && posture[key] !== null) {
(masked as Record<keyof PostureData, "-">)[key] = "-";
}
}
return masked;
}
function getPlatformPostureData( function getPlatformPostureData(
platform: string | null | undefined, platform: string | null | undefined,
fingerprint: typeof currentFingerprint.$inferSelect | null fingerprint: typeof currentFingerprint.$inferSelect | null
@@ -295,11 +284,9 @@ export async function getClient(
); );
} }
const isUserDevice = client.user !== null && client.user !== undefined;
// Replace name with device name if OLM exists // Replace name with device name if OLM exists
let clientName = client.clients.name; let clientName = client.clients.name;
if (client.olms && isUserDevice) { if (client.olms) {
const model = client.currentFingerprint?.deviceModel || null; const model = client.currentFingerprint?.deviceModel || null;
clientName = getUserDeviceName(model, client.clients.name); clientName = getUserDeviceName(model, client.clients.name);
} }
@@ -307,35 +294,32 @@ export async function getClient(
// Build fingerprint data if available // Build fingerprint data if available
const fingerprintData = client.currentFingerprint const fingerprintData = client.currentFingerprint
? { ? {
username: client.currentFingerprint.username || null, username: client.currentFingerprint.username || null,
hostname: client.currentFingerprint.hostname || null, hostname: client.currentFingerprint.hostname || null,
platform: client.currentFingerprint.platform || null, platform: client.currentFingerprint.platform || null,
osVersion: client.currentFingerprint.osVersion || null, osVersion: client.currentFingerprint.osVersion || null,
kernelVersion: kernelVersion:
client.currentFingerprint.kernelVersion || null, client.currentFingerprint.kernelVersion || null,
arch: client.currentFingerprint.arch || null, arch: client.currentFingerprint.arch || null,
deviceModel: client.currentFingerprint.deviceModel || null, deviceModel: client.currentFingerprint.deviceModel || null,
serialNumber: client.currentFingerprint.serialNumber || null, serialNumber: client.currentFingerprint.serialNumber || null,
firstSeen: client.currentFingerprint.firstSeen || null, firstSeen: client.currentFingerprint.firstSeen || null,
lastSeen: client.currentFingerprint.lastSeen || null lastSeen: client.currentFingerprint.lastSeen || null
} }
: null; : null;
// Build posture data if available (platform-specific) // Build posture data if available (platform-specific)
// Licensed: real values; not licensed: same keys but values set to "-" // Only return posture data if org is licensed/subscribed
const rawPosture = getPlatformPostureData( let postureData: PostureData | null = null;
client.currentFingerprint?.platform || null,
client.currentFingerprint
);
const isOrgLicensed = await isLicensedOrSubscribed( const isOrgLicensed = await isLicensedOrSubscribed(
client.clients.orgId, client.clients.orgId
tierMatrix.devicePosture
); );
const postureData: PostureData | null = rawPosture if (isOrgLicensed) {
? isOrgLicensed postureData = getPlatformPostureData(
? rawPosture client.currentFingerprint?.platform || null,
: maskPostureDataWithPlaceholder(rawPosture) client.currentFingerprint
: null; );
}
const data: GetClientResponse = { const data: GetClientResponse = {
...client.clients, ...client.clients,

View File

@@ -320,10 +320,7 @@ export async function listClients(
// Merge clients with their site associations and replace name with device name // Merge clients with their site associations and replace name with device name
const clientsWithSites = clientsList.map((client) => { const clientsWithSites = clientsList.map((client) => {
const model = client.deviceModel || null; const model = client.deviceModel || null;
let newName = client.name; const newName = getUserDeviceName(model, client.name);
if (filter === "user") {
newName = getUserDeviceName(model, client.name);
}
return { return {
...client, ...client,
name: newName, name: newName,

View File

@@ -131,7 +131,7 @@ export async function createOrgDomain(
} }
const rejectDomains = await usageService.checkLimitSet( const rejectDomains = await usageService.checkLimitSet(
orgId, orgId,
false,
FeatureId.DOMAINS, FeatureId.DOMAINS,
{ {
...usage, ...usage,
@@ -354,7 +354,7 @@ export async function createOrgDomain(
}); });
if (numOrgDomains) { if (numOrgDomains) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.DOMAINS, FeatureId.DOMAINS,
numOrgDomains.length numOrgDomains.length

View File

@@ -86,7 +86,7 @@ export async function deleteAccountDomain(
}); });
if (numOrgDomains) { if (numOrgDomains) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.DOMAINS, FeatureId.DOMAINS,
numOrgDomains.length numOrgDomains.length

View File

@@ -41,8 +41,7 @@ import {
verifyUserHasAction, verifyUserHasAction,
verifyUserIsOrgOwner, verifyUserIsOrgOwner,
verifySiteResourceAccess, verifySiteResourceAccess,
verifyOlmAccess, verifyOlmAccess
verifyLimits
} from "@server/middlewares"; } from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions"; import { ActionsEnum } from "@server/auth/actions";
import rateLimit, { ipKeyGenerator } from "express-rate-limit"; import rateLimit, { ipKeyGenerator } from "express-rate-limit";
@@ -80,7 +79,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/org/:orgId", "/org/:orgId",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateOrg), verifyUserHasAction(ActionsEnum.updateOrg),
logActionAudit(ActionsEnum.updateOrg), logActionAudit(ActionsEnum.updateOrg),
org.updateOrg org.updateOrg
@@ -163,7 +161,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/client", "/org/:orgId/client",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createClient), verifyUserHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient), logActionAudit(ActionsEnum.createClient),
client.createClient client.createClient
@@ -181,7 +178,6 @@ authenticated.delete(
authenticated.post( authenticated.post(
"/client/:clientId/archive", "/client/:clientId/archive",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.archiveClient), verifyUserHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient), logActionAudit(ActionsEnum.archiveClient),
client.archiveClient client.archiveClient
@@ -190,7 +186,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/unarchive", "/client/:clientId/unarchive",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unarchiveClient), verifyUserHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient), logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient client.unarchiveClient
@@ -199,7 +194,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/block", "/client/:clientId/block",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.blockClient), verifyUserHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient), logActionAudit(ActionsEnum.blockClient),
client.blockClient client.blockClient
@@ -208,7 +202,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/unblock", "/client/:clientId/unblock",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unblockClient), verifyUserHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient), logActionAudit(ActionsEnum.unblockClient),
client.unblockClient client.unblockClient
@@ -217,7 +210,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId", "/client/:clientId",
verifyClientAccess, // this will check if the user has access to the client verifyClientAccess, // this will check if the user has access to the client
verifyLimits,
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
logActionAudit(ActionsEnum.updateClient), logActionAudit(ActionsEnum.updateClient),
client.updateClient client.updateClient
@@ -232,7 +224,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/site/:siteId", "/site/:siteId",
verifySiteAccess, verifySiteAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSite), verifyUserHasAction(ActionsEnum.updateSite),
logActionAudit(ActionsEnum.updateSite), logActionAudit(ActionsEnum.updateSite),
site.updateSite site.updateSite
@@ -282,7 +273,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/site-resource", "/org/:orgId/site-resource",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createSiteResource), verifyUserHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource), logActionAudit(ActionsEnum.createSiteResource),
siteResource.createSiteResource siteResource.createSiteResource
@@ -313,7 +303,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/site-resource/:siteResourceId", "/site-resource/:siteResourceId",
verifySiteResourceAccess, verifySiteResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSiteResource), verifyUserHasAction(ActionsEnum.updateSiteResource),
logActionAudit(ActionsEnum.updateSiteResource), logActionAudit(ActionsEnum.updateSiteResource),
siteResource.updateSiteResource siteResource.updateSiteResource
@@ -352,7 +341,6 @@ authenticated.post(
"/site-resource/:siteResourceId/roles", "/site-resource/:siteResourceId/roles",
verifySiteResourceAccess, verifySiteResourceAccess,
verifyRoleAccess, verifyRoleAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceRoles), verifyUserHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles siteResource.setSiteResourceRoles
@@ -362,7 +350,6 @@ authenticated.post(
"/site-resource/:siteResourceId/users", "/site-resource/:siteResourceId/users",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceUsers, verifySetResourceUsers,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers siteResource.setSiteResourceUsers
@@ -372,7 +359,6 @@ authenticated.post(
"/site-resource/:siteResourceId/clients", "/site-resource/:siteResourceId/clients",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceClients, verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients siteResource.setSiteResourceClients
@@ -382,7 +368,6 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/add", "/site-resource/:siteResourceId/clients/add",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceClients, verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource siteResource.addClientToSiteResource
@@ -392,7 +377,6 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/remove", "/site-resource/:siteResourceId/clients/remove",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceClients, verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource siteResource.removeClientFromSiteResource
@@ -401,7 +385,6 @@ authenticated.post(
authenticated.put( authenticated.put(
"/org/:orgId/resource", "/org/:orgId/resource",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResource), verifyUserHasAction(ActionsEnum.createResource),
logActionAudit(ActionsEnum.createResource), logActionAudit(ActionsEnum.createResource),
resource.createResource resource.createResource
@@ -516,7 +499,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/resource/:resourceId", "/resource/:resourceId",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateResource), verifyUserHasAction(ActionsEnum.updateResource),
logActionAudit(ActionsEnum.updateResource), logActionAudit(ActionsEnum.updateResource),
resource.updateResource resource.updateResource
@@ -532,7 +514,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/resource/:resourceId/target", "/resource/:resourceId/target",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createTarget), verifyUserHasAction(ActionsEnum.createTarget),
logActionAudit(ActionsEnum.createTarget), logActionAudit(ActionsEnum.createTarget),
target.createTarget target.createTarget
@@ -547,7 +528,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/resource/:resourceId/rule", "/resource/:resourceId/rule",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResourceRule), verifyUserHasAction(ActionsEnum.createResourceRule),
logActionAudit(ActionsEnum.createResourceRule), logActionAudit(ActionsEnum.createResourceRule),
resource.createResourceRule resource.createResourceRule
@@ -561,7 +541,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/resource/:resourceId/rule/:ruleId", "/resource/:resourceId/rule/:ruleId",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateResourceRule), verifyUserHasAction(ActionsEnum.updateResourceRule),
logActionAudit(ActionsEnum.updateResourceRule), logActionAudit(ActionsEnum.updateResourceRule),
resource.updateResourceRule resource.updateResourceRule
@@ -583,7 +562,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/target/:targetId", "/target/:targetId",
verifyTargetAccess, verifyTargetAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateTarget), verifyUserHasAction(ActionsEnum.updateTarget),
logActionAudit(ActionsEnum.updateTarget), logActionAudit(ActionsEnum.updateTarget),
target.updateTarget target.updateTarget
@@ -599,7 +577,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/role", "/org/:orgId/role",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createRole), verifyUserHasAction(ActionsEnum.createRole),
logActionAudit(ActionsEnum.createRole), logActionAudit(ActionsEnum.createRole),
role.createRole role.createRole
@@ -614,7 +591,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/role/:roleId", "/role/:roleId",
verifyRoleAccess, verifyRoleAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateRole), verifyUserHasAction(ActionsEnum.updateRole),
logActionAudit(ActionsEnum.updateRole), logActionAudit(ActionsEnum.updateRole),
role.updateRole role.updateRole
@@ -643,7 +619,6 @@ authenticated.post(
"/role/:roleId/add/:userId", "/role/:roleId/add/:userId",
verifyRoleAccess, verifyRoleAccess,
verifyUserAccess, verifyUserAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.addUserRole), verifyUserHasAction(ActionsEnum.addUserRole),
logActionAudit(ActionsEnum.addUserRole), logActionAudit(ActionsEnum.addUserRole),
user.addUserRole user.addUserRole
@@ -653,7 +628,6 @@ authenticated.post(
"/resource/:resourceId/roles", "/resource/:resourceId/roles",
verifyResourceAccess, verifyResourceAccess,
verifyRoleAccess, verifyRoleAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceRoles), verifyUserHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
resource.setResourceRoles resource.setResourceRoles
@@ -663,7 +637,6 @@ authenticated.post(
"/resource/:resourceId/users", "/resource/:resourceId/users",
verifyResourceAccess, verifyResourceAccess,
verifySetResourceUsers, verifySetResourceUsers,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
resource.setResourceUsers resource.setResourceUsers
@@ -672,7 +645,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/password`, `/resource/:resourceId/password`,
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourcePassword), verifyUserHasAction(ActionsEnum.setResourcePassword),
logActionAudit(ActionsEnum.setResourcePassword), logActionAudit(ActionsEnum.setResourcePassword),
resource.setResourcePassword resource.setResourcePassword
@@ -681,7 +653,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/pincode`, `/resource/:resourceId/pincode`,
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourcePincode), verifyUserHasAction(ActionsEnum.setResourcePincode),
logActionAudit(ActionsEnum.setResourcePincode), logActionAudit(ActionsEnum.setResourcePincode),
resource.setResourcePincode resource.setResourcePincode
@@ -690,7 +661,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/header-auth`, `/resource/:resourceId/header-auth`,
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceHeaderAuth), verifyUserHasAction(ActionsEnum.setResourceHeaderAuth),
logActionAudit(ActionsEnum.setResourceHeaderAuth), logActionAudit(ActionsEnum.setResourceHeaderAuth),
resource.setResourceHeaderAuth resource.setResourceHeaderAuth
@@ -699,7 +669,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/whitelist`, `/resource/:resourceId/whitelist`,
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceWhitelist), verifyUserHasAction(ActionsEnum.setResourceWhitelist),
logActionAudit(ActionsEnum.setResourceWhitelist), logActionAudit(ActionsEnum.setResourceWhitelist),
resource.setResourceWhitelist resource.setResourceWhitelist
@@ -715,7 +684,6 @@ authenticated.get(
authenticated.post( authenticated.post(
`/resource/:resourceId/access-token`, `/resource/:resourceId/access-token`,
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.generateAccessToken), verifyUserHasAction(ActionsEnum.generateAccessToken),
logActionAudit(ActionsEnum.generateAccessToken), logActionAudit(ActionsEnum.generateAccessToken),
accessToken.generateAccessToken accessToken.generateAccessToken
@@ -806,7 +774,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/user", "/org/:orgId/user",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgUser), verifyUserHasAction(ActionsEnum.createOrgUser),
logActionAudit(ActionsEnum.createOrgUser), logActionAudit(ActionsEnum.createOrgUser),
user.createOrgUser user.createOrgUser
@@ -816,7 +783,6 @@ authenticated.post(
"/org/:orgId/user/:userId", "/org/:orgId/user/:userId",
verifyOrgAccess, verifyOrgAccess,
verifyUserAccess, verifyUserAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateOrgUser), verifyUserHasAction(ActionsEnum.updateOrgUser),
logActionAudit(ActionsEnum.updateOrgUser), logActionAudit(ActionsEnum.updateOrgUser),
user.updateOrgUser user.updateOrgUser
@@ -889,7 +855,6 @@ authenticated.post(
"/user/:userId/olm/:olmId/archive", "/user/:userId/olm/:olmId/archive",
verifyIsLoggedInUser, verifyIsLoggedInUser,
verifyOlmAccess, verifyOlmAccess,
verifyLimits,
olm.archiveUserOlm olm.archiveUserOlm
); );
@@ -1004,7 +969,6 @@ authenticated.post(
`/org/:orgId/api-key/:apiKeyId/actions`, `/org/:orgId/api-key/:apiKeyId/actions`,
verifyOrgAccess, verifyOrgAccess,
verifyApiKeyAccess, verifyApiKeyAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setApiKeyActions), verifyUserHasAction(ActionsEnum.setApiKeyActions),
logActionAudit(ActionsEnum.setApiKeyActions), logActionAudit(ActionsEnum.setApiKeyActions),
apiKeys.setApiKeyActions apiKeys.setApiKeyActions
@@ -1021,7 +985,6 @@ authenticated.get(
authenticated.put( authenticated.put(
`/org/:orgId/api-key`, `/org/:orgId/api-key`,
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createApiKey), verifyUserHasAction(ActionsEnum.createApiKey),
logActionAudit(ActionsEnum.createApiKey), logActionAudit(ActionsEnum.createApiKey),
apiKeys.createOrgApiKey apiKeys.createOrgApiKey
@@ -1047,7 +1010,6 @@ authenticated.get(
authenticated.put( authenticated.put(
`/org/:orgId/domain`, `/org/:orgId/domain`,
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgDomain), verifyUserHasAction(ActionsEnum.createOrgDomain),
logActionAudit(ActionsEnum.createOrgDomain), logActionAudit(ActionsEnum.createOrgDomain),
domain.createOrgDomain domain.createOrgDomain
@@ -1057,7 +1019,6 @@ authenticated.post(
`/org/:orgId/domain/:domainId/restart`, `/org/:orgId/domain/:domainId/restart`,
verifyOrgAccess, verifyOrgAccess,
verifyDomainAccess, verifyDomainAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.restartOrgDomain), verifyUserHasAction(ActionsEnum.restartOrgDomain),
logActionAudit(ActionsEnum.restartOrgDomain), logActionAudit(ActionsEnum.restartOrgDomain),
domain.restartOrgDomain domain.restartOrgDomain
@@ -1104,7 +1065,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/blueprint", "/org/:orgId/blueprint",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.applyBlueprint), verifyUserHasAction(ActionsEnum.applyBlueprint),
blueprints.applyYAMLBlueprint blueprints.applyYAMLBlueprint
); );

View File

@@ -114,6 +114,7 @@ export async function updateSiteBandwidth(
// Aggregate usage data by organization (collected outside transaction) // Aggregate usage data by organization (collected outside transaction)
const orgUsageMap = new Map<string, number>(); const orgUsageMap = new Map<string, number>();
const orgUptimeMap = new Map<string, number>();
if (activePeers.length > 0) { if (activePeers.length > 0) {
// Remove any active peers from offline tracking since they're sending data // Remove any active peers from offline tracking since they're sending data
@@ -165,6 +166,14 @@ export async function updateSiteBandwidth(
updatedSite.orgId, updatedSite.orgId,
currentOrgUsage + totalBandwidth currentOrgUsage + totalBandwidth
); );
// Add 10 seconds of uptime for each active site
const currentOrgUptime =
orgUptimeMap.get(updatedSite.orgId) || 0;
orgUptimeMap.set(
updatedSite.orgId,
currentOrgUptime + 10 / 60
);
} }
} catch (error) { } catch (error) {
logger.error( logger.error(
@@ -178,9 +187,11 @@ export async function updateSiteBandwidth(
// Process usage updates outside of site update transactions // Process usage updates outside of site update transactions
// This separates the concerns and reduces lock contention // This separates the concerns and reduces lock contention
if (calcUsageAndLimits && orgUsageMap.size > 0) { if (calcUsageAndLimits && (orgUsageMap.size > 0 || orgUptimeMap.size > 0)) {
// Sort org IDs to ensure consistent lock ordering // Sort org IDs to ensure consistent lock ordering
const allOrgIds = [...new Set([...orgUsageMap.keys()])].sort(); const allOrgIds = [
...new Set([...orgUsageMap.keys(), ...orgUptimeMap.keys()])
].sort();
for (const orgId of allOrgIds) { for (const orgId of allOrgIds) {
try { try {
@@ -197,7 +208,7 @@ export async function updateSiteBandwidth(
usageService usageService
.checkLimitSet( .checkLimitSet(
orgId, orgId,
true,
FeatureId.EGRESS_DATA_MB, FeatureId.EGRESS_DATA_MB,
bandwidthUsage bandwidthUsage
) )
@@ -209,6 +220,32 @@ export async function updateSiteBandwidth(
}); });
} }
} }
// Process uptime usage for this org
const totalUptime = orgUptimeMap.get(orgId);
if (totalUptime) {
const uptimeUsage = await usageService.add(
orgId,
FeatureId.SITE_UPTIME,
totalUptime
);
if (uptimeUsage) {
// Fire and forget - don't block on limit checking
usageService
.checkLimitSet(
orgId,
true,
FeatureId.SITE_UPTIME,
uptimeUsage
)
.catch((error: any) => {
logger.error(
`Error checking uptime limits for org ${orgId}:`,
error
);
});
}
}
} catch (error) { } catch (error) {
logger.error(`Error processing usage for org ${orgId}:`, error); logger.error(`Error processing usage for org ${orgId}:`, error);
// Continue with other orgs // Continue with other orgs

View File

@@ -93,9 +93,7 @@ export async function createOidcIdp(
name, name,
autoProvision, autoProvision,
type: "oidc", type: "oidc",
tags, tags
defaultOrgMapping: `'{{orgId}}'`,
defaultRoleMapping: `'Member'`
}) })
.returning(); .returning();

View File

@@ -14,8 +14,8 @@ import jsonwebtoken from "jsonwebtoken";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { decrypt } from "@server/lib/crypto"; import { decrypt } from "@server/lib/crypto";
import { build } from "@server/build"; import { build } from "@server/build";
import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { getOrgTierData } from "#dynamic/lib/billing";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { TierId } from "@server/lib/billing/tiers";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@@ -113,10 +113,8 @@ export async function generateOidcUrl(
} }
if (build === "saas") { if (build === "saas") {
const subscribed = await isSubscribed( const { tier } = await getOrgTierData(orgId);
orgId, const subscribed = tier === TierId.STANDARD;
tierMatrix.orgOidc
);
if (!subscribed) { if (!subscribed) {
return next( return next(
createHttpError( createHttpError(

View File

@@ -34,8 +34,6 @@ import { FeatureId } from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build"; import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const ensureTrailingSlash = (url: string): string => { const ensureTrailingSlash = (url: string): string => {
return url; return url;
@@ -328,33 +326,6 @@ export async function validateOidcCallback(
.where(eq(idpOrg.idpId, existingIdp.idp.idpId)) .where(eq(idpOrg.idpId, existingIdp.idp.idpId))
.innerJoin(orgs, eq(orgs.orgId, idpOrg.orgId)); .innerJoin(orgs, eq(orgs.orgId, idpOrg.orgId));
allOrgs = idpOrgs.map((o) => o.orgs); allOrgs = idpOrgs.map((o) => o.orgs);
// TODO: when there are multiple orgs we need to do this better!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!1
if (allOrgs.length > 1) {
// for some reason there is more than one org
logger.error(
"More than one organization linked to this IdP. This should not happen with auto-provisioning enabled."
);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Multiple organizations linked to this IdP. Please contact support."
)
);
}
const subscribed = await isSubscribed(
allOrgs[0].orgId,
tierMatrix.autoProvisioning
);
if (subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
} else { } else {
allOrgs = await db.select().from(orgs); allOrgs = await db.select().from(orgs);
} }
@@ -616,7 +587,7 @@ export async function validateOidcCallback(
}); });
for (const orgCount of orgUserCounts) { for (const orgCount of orgUserCounts) {
await usageService.updateCount( await usageService.updateDaily(
orgCount.orgId, orgCount.orgId,
FeatureId.USERS, FeatureId.USERS,
orgCount.userCount orgCount.userCount

View File

@@ -26,8 +26,7 @@ import {
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyApiKeyClientAccess, verifyApiKeyClientAccess,
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients, verifyApiKeySetResourceClients
verifyLimits
} from "@server/middlewares"; } from "@server/middlewares";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { Router } from "express"; import { Router } from "express";
@@ -75,7 +74,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/org/:orgId", "/org/:orgId",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateOrg), verifyApiKeyHasAction(ActionsEnum.updateOrg),
logActionAudit(ActionsEnum.updateOrg), logActionAudit(ActionsEnum.updateOrg),
org.updateOrg org.updateOrg
@@ -92,7 +90,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/site", "/org/:orgId/site",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createSite), verifyApiKeyHasAction(ActionsEnum.createSite),
logActionAudit(ActionsEnum.createSite), logActionAudit(ActionsEnum.createSite),
site.createSite site.createSite
@@ -129,7 +126,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/site/:siteId", "/site/:siteId",
verifyApiKeySiteAccess, verifyApiKeySiteAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateSite), verifyApiKeyHasAction(ActionsEnum.updateSite),
logActionAudit(ActionsEnum.updateSite), logActionAudit(ActionsEnum.updateSite),
site.updateSite site.updateSite
@@ -150,9 +146,8 @@ authenticated.get(
); );
// Site Resource endpoints // Site Resource endpoints
authenticated.put( authenticated.put(
"/org/:orgId/site-resource", "/org/:orgId/private-resource",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createSiteResource), verifyApiKeyHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource), logActionAudit(ActionsEnum.createSiteResource),
siteResource.createSiteResource siteResource.createSiteResource
@@ -183,7 +178,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/site-resource/:siteResourceId", "/site-resource/:siteResourceId",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateSiteResource), verifyApiKeyHasAction(ActionsEnum.updateSiteResource),
logActionAudit(ActionsEnum.updateSiteResource), logActionAudit(ActionsEnum.updateSiteResource),
siteResource.updateSiteResource siteResource.updateSiteResource
@@ -222,7 +216,6 @@ authenticated.post(
"/site-resource/:siteResourceId/roles", "/site-resource/:siteResourceId/roles",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles), verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles siteResource.setSiteResourceRoles
@@ -232,7 +225,6 @@ authenticated.post(
"/site-resource/:siteResourceId/users", "/site-resource/:siteResourceId/users",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceUsers, verifyApiKeySetResourceUsers,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers siteResource.setSiteResourceUsers
@@ -242,7 +234,6 @@ authenticated.post(
"/site-resource/:siteResourceId/roles/add", "/site-resource/:siteResourceId/roles/add",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles), verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
siteResource.addRoleToSiteResource siteResource.addRoleToSiteResource
@@ -252,7 +243,6 @@ authenticated.post(
"/site-resource/:siteResourceId/roles/remove", "/site-resource/:siteResourceId/roles/remove",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles), verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
siteResource.removeRoleFromSiteResource siteResource.removeRoleFromSiteResource
@@ -262,7 +252,6 @@ authenticated.post(
"/site-resource/:siteResourceId/users/add", "/site-resource/:siteResourceId/users/add",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceUsers, verifyApiKeySetResourceUsers,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addUserToSiteResource siteResource.addUserToSiteResource
@@ -272,7 +261,6 @@ authenticated.post(
"/site-resource/:siteResourceId/users/remove", "/site-resource/:siteResourceId/users/remove",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceUsers, verifyApiKeySetResourceUsers,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeUserFromSiteResource siteResource.removeUserFromSiteResource
@@ -282,7 +270,6 @@ authenticated.post(
"/site-resource/:siteResourceId/clients", "/site-resource/:siteResourceId/clients",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients, verifyApiKeySetResourceClients,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients siteResource.setSiteResourceClients
@@ -292,7 +279,6 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/add", "/site-resource/:siteResourceId/clients/add",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients, verifyApiKeySetResourceClients,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource siteResource.addClientToSiteResource
@@ -302,7 +288,6 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/remove", "/site-resource/:siteResourceId/clients/remove",
verifyApiKeySiteResourceAccess, verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients, verifyApiKeySetResourceClients,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource siteResource.removeClientFromSiteResource
@@ -311,7 +296,6 @@ authenticated.post(
authenticated.put( authenticated.put(
"/org/:orgId/resource", "/org/:orgId/resource",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createResource), verifyApiKeyHasAction(ActionsEnum.createResource),
logActionAudit(ActionsEnum.createResource), logActionAudit(ActionsEnum.createResource),
resource.createResource resource.createResource
@@ -320,7 +304,6 @@ authenticated.put(
authenticated.put( authenticated.put(
"/org/:orgId/site/:siteId/resource", "/org/:orgId/site/:siteId/resource",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createResource), verifyApiKeyHasAction(ActionsEnum.createResource),
logActionAudit(ActionsEnum.createResource), logActionAudit(ActionsEnum.createResource),
resource.createResource resource.createResource
@@ -357,7 +340,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/org/:orgId/create-invite", "/org/:orgId/create-invite",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.inviteUser), verifyApiKeyHasAction(ActionsEnum.inviteUser),
logActionAudit(ActionsEnum.inviteUser), logActionAudit(ActionsEnum.inviteUser),
user.inviteUser user.inviteUser
@@ -395,7 +377,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/resource/:resourceId", "/resource/:resourceId",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateResource), verifyApiKeyHasAction(ActionsEnum.updateResource),
logActionAudit(ActionsEnum.updateResource), logActionAudit(ActionsEnum.updateResource),
resource.updateResource resource.updateResource
@@ -412,7 +393,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/resource/:resourceId/target", "/resource/:resourceId/target",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createTarget), verifyApiKeyHasAction(ActionsEnum.createTarget),
logActionAudit(ActionsEnum.createTarget), logActionAudit(ActionsEnum.createTarget),
target.createTarget target.createTarget
@@ -428,7 +408,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/resource/:resourceId/rule", "/resource/:resourceId/rule",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createResourceRule), verifyApiKeyHasAction(ActionsEnum.createResourceRule),
logActionAudit(ActionsEnum.createResourceRule), logActionAudit(ActionsEnum.createResourceRule),
resource.createResourceRule resource.createResourceRule
@@ -444,7 +423,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/resource/:resourceId/rule/:ruleId", "/resource/:resourceId/rule/:ruleId",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateResourceRule), verifyApiKeyHasAction(ActionsEnum.updateResourceRule),
logActionAudit(ActionsEnum.updateResourceRule), logActionAudit(ActionsEnum.updateResourceRule),
resource.updateResourceRule resource.updateResourceRule
@@ -468,7 +446,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/target/:targetId", "/target/:targetId",
verifyApiKeyTargetAccess, verifyApiKeyTargetAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateTarget), verifyApiKeyHasAction(ActionsEnum.updateTarget),
logActionAudit(ActionsEnum.updateTarget), logActionAudit(ActionsEnum.updateTarget),
target.updateTarget target.updateTarget
@@ -485,7 +462,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/role", "/org/:orgId/role",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createRole), verifyApiKeyHasAction(ActionsEnum.createRole),
logActionAudit(ActionsEnum.createRole), logActionAudit(ActionsEnum.createRole),
role.createRole role.createRole
@@ -494,7 +470,6 @@ authenticated.put(
authenticated.post( authenticated.post(
"/role/:roleId", "/role/:roleId",
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateRole), verifyApiKeyHasAction(ActionsEnum.updateRole),
logActionAudit(ActionsEnum.updateRole), logActionAudit(ActionsEnum.updateRole),
role.updateRole role.updateRole
@@ -526,7 +501,6 @@ authenticated.post(
"/role/:roleId/add/:userId", "/role/:roleId/add/:userId",
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyApiKeyUserAccess, verifyApiKeyUserAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.addUserRole), verifyApiKeyHasAction(ActionsEnum.addUserRole),
logActionAudit(ActionsEnum.addUserRole), logActionAudit(ActionsEnum.addUserRole),
user.addUserRole user.addUserRole
@@ -536,7 +510,6 @@ authenticated.post(
"/resource/:resourceId/roles", "/resource/:resourceId/roles",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles), verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
resource.setResourceRoles resource.setResourceRoles
@@ -546,7 +519,6 @@ authenticated.post(
"/resource/:resourceId/users", "/resource/:resourceId/users",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyApiKeySetResourceUsers, verifyApiKeySetResourceUsers,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
resource.setResourceUsers resource.setResourceUsers
@@ -556,7 +528,6 @@ authenticated.post(
"/resource/:resourceId/roles/add", "/resource/:resourceId/roles/add",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles), verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
resource.addRoleToResource resource.addRoleToResource
@@ -566,7 +537,6 @@ authenticated.post(
"/resource/:resourceId/roles/remove", "/resource/:resourceId/roles/remove",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyApiKeyRoleAccess, verifyApiKeyRoleAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles), verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
resource.removeRoleFromResource resource.removeRoleFromResource
@@ -576,7 +546,6 @@ authenticated.post(
"/resource/:resourceId/users/add", "/resource/:resourceId/users/add",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyApiKeySetResourceUsers, verifyApiKeySetResourceUsers,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
resource.addUserToResource resource.addUserToResource
@@ -586,7 +555,6 @@ authenticated.post(
"/resource/:resourceId/users/remove", "/resource/:resourceId/users/remove",
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyApiKeySetResourceUsers, verifyApiKeySetResourceUsers,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers), verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
resource.removeUserFromResource resource.removeUserFromResource
@@ -595,7 +563,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/password`, `/resource/:resourceId/password`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourcePassword), verifyApiKeyHasAction(ActionsEnum.setResourcePassword),
logActionAudit(ActionsEnum.setResourcePassword), logActionAudit(ActionsEnum.setResourcePassword),
resource.setResourcePassword resource.setResourcePassword
@@ -604,7 +571,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/pincode`, `/resource/:resourceId/pincode`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourcePincode), verifyApiKeyHasAction(ActionsEnum.setResourcePincode),
logActionAudit(ActionsEnum.setResourcePincode), logActionAudit(ActionsEnum.setResourcePincode),
resource.setResourcePincode resource.setResourcePincode
@@ -613,7 +579,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/header-auth`, `/resource/:resourceId/header-auth`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceHeaderAuth), verifyApiKeyHasAction(ActionsEnum.setResourceHeaderAuth),
logActionAudit(ActionsEnum.setResourceHeaderAuth), logActionAudit(ActionsEnum.setResourceHeaderAuth),
resource.setResourceHeaderAuth resource.setResourceHeaderAuth
@@ -622,7 +587,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/whitelist`, `/resource/:resourceId/whitelist`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist), verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist),
logActionAudit(ActionsEnum.setResourceWhitelist), logActionAudit(ActionsEnum.setResourceWhitelist),
resource.setResourceWhitelist resource.setResourceWhitelist
@@ -631,7 +595,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/whitelist/add`, `/resource/:resourceId/whitelist/add`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist), verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist),
resource.addEmailToResourceWhitelist resource.addEmailToResourceWhitelist
); );
@@ -639,7 +602,6 @@ authenticated.post(
authenticated.post( authenticated.post(
`/resource/:resourceId/whitelist/remove`, `/resource/:resourceId/whitelist/remove`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist), verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist),
resource.removeEmailFromResourceWhitelist resource.removeEmailFromResourceWhitelist
); );
@@ -654,7 +616,6 @@ authenticated.get(
authenticated.post( authenticated.post(
`/resource/:resourceId/access-token`, `/resource/:resourceId/access-token`,
verifyApiKeyResourceAccess, verifyApiKeyResourceAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.generateAccessToken), verifyApiKeyHasAction(ActionsEnum.generateAccessToken),
logActionAudit(ActionsEnum.generateAccessToken), logActionAudit(ActionsEnum.generateAccessToken),
accessToken.generateAccessToken accessToken.generateAccessToken
@@ -692,7 +653,6 @@ authenticated.get(
authenticated.post( authenticated.post(
"/user/:userId/2fa", "/user/:userId/2fa",
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateUser), verifyApiKeyHasAction(ActionsEnum.updateUser),
logActionAudit(ActionsEnum.updateUser), logActionAudit(ActionsEnum.updateUser),
user.updateUser2FA user.updateUser2FA
@@ -715,7 +675,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/user", "/org/:orgId/user",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createOrgUser), verifyApiKeyHasAction(ActionsEnum.createOrgUser),
logActionAudit(ActionsEnum.createOrgUser), logActionAudit(ActionsEnum.createOrgUser),
user.createOrgUser user.createOrgUser
@@ -725,7 +684,6 @@ authenticated.post(
"/org/:orgId/user/:userId", "/org/:orgId/user/:userId",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyUserAccess, verifyApiKeyUserAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateOrgUser), verifyApiKeyHasAction(ActionsEnum.updateOrgUser),
logActionAudit(ActionsEnum.updateOrgUser), logActionAudit(ActionsEnum.updateOrgUser),
user.updateOrgUser user.updateOrgUser
@@ -756,7 +714,6 @@ authenticated.get(
authenticated.post( authenticated.post(
`/org/:orgId/api-key/:apiKeyId/actions`, `/org/:orgId/api-key/:apiKeyId/actions`,
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setApiKeyActions), verifyApiKeyHasAction(ActionsEnum.setApiKeyActions),
logActionAudit(ActionsEnum.setApiKeyActions), logActionAudit(ActionsEnum.setApiKeyActions),
apiKeys.setApiKeyActions apiKeys.setApiKeyActions
@@ -772,7 +729,6 @@ authenticated.get(
authenticated.put( authenticated.put(
`/org/:orgId/api-key`, `/org/:orgId/api-key`,
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createApiKey), verifyApiKeyHasAction(ActionsEnum.createApiKey),
logActionAudit(ActionsEnum.createApiKey), logActionAudit(ActionsEnum.createApiKey),
apiKeys.createOrgApiKey apiKeys.createOrgApiKey
@@ -789,7 +745,6 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/idp/oidc", "/idp/oidc",
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createIdp), verifyApiKeyHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp), logActionAudit(ActionsEnum.createIdp),
idp.createOidcIdp idp.createOidcIdp
@@ -798,7 +753,6 @@ authenticated.put(
authenticated.post( authenticated.post(
"/idp/:idpId/oidc", "/idp/:idpId/oidc",
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateIdp), verifyApiKeyHasAction(ActionsEnum.updateIdp),
logActionAudit(ActionsEnum.updateIdp), logActionAudit(ActionsEnum.updateIdp),
idp.updateOidcIdp idp.updateOidcIdp
@@ -822,7 +776,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/idp/:idpId/org/:orgId", "/idp/:idpId/org/:orgId",
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createIdpOrg), verifyApiKeyHasAction(ActionsEnum.createIdpOrg),
logActionAudit(ActionsEnum.createIdpOrg), logActionAudit(ActionsEnum.createIdpOrg),
idp.createIdpOrgPolicy idp.createIdpOrgPolicy
@@ -831,7 +784,6 @@ authenticated.put(
authenticated.post( authenticated.post(
"/idp/:idpId/org/:orgId", "/idp/:idpId/org/:orgId",
verifyApiKeyIsRoot, verifyApiKeyIsRoot,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateIdpOrg), verifyApiKeyHasAction(ActionsEnum.updateIdpOrg),
logActionAudit(ActionsEnum.updateIdpOrg), logActionAudit(ActionsEnum.updateIdpOrg),
idp.updateIdpOrgPolicy idp.updateIdpOrgPolicy
@@ -876,7 +828,6 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/client", "/org/:orgId/client",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createClient), verifyApiKeyHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient), logActionAudit(ActionsEnum.createClient),
client.createClient client.createClient
@@ -903,7 +854,6 @@ authenticated.delete(
authenticated.post( authenticated.post(
"/client/:clientId/archive", "/client/:clientId/archive",
verifyApiKeyClientAccess, verifyApiKeyClientAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.archiveClient), verifyApiKeyHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient), logActionAudit(ActionsEnum.archiveClient),
client.archiveClient client.archiveClient
@@ -912,7 +862,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/unarchive", "/client/:clientId/unarchive",
verifyApiKeyClientAccess, verifyApiKeyClientAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.unarchiveClient), verifyApiKeyHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient), logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient client.unarchiveClient
@@ -921,7 +870,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/block", "/client/:clientId/block",
verifyApiKeyClientAccess, verifyApiKeyClientAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.blockClient), verifyApiKeyHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient), logActionAudit(ActionsEnum.blockClient),
client.blockClient client.blockClient
@@ -930,7 +878,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/unblock", "/client/:clientId/unblock",
verifyApiKeyClientAccess, verifyApiKeyClientAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.unblockClient), verifyApiKeyHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient), logActionAudit(ActionsEnum.unblockClient),
client.unblockClient client.unblockClient
@@ -939,7 +886,6 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId", "/client/:clientId",
verifyApiKeyClientAccess, verifyApiKeyClientAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateClient), verifyApiKeyHasAction(ActionsEnum.updateClient),
logActionAudit(ActionsEnum.updateClient), logActionAudit(ActionsEnum.updateClient),
client.updateClient client.updateClient
@@ -948,7 +894,6 @@ authenticated.post(
authenticated.put( authenticated.put(
"/org/:orgId/blueprint", "/org/:orgId/blueprint",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.applyBlueprint), verifyApiKeyHasAction(ActionsEnum.applyBlueprint),
logActionAudit(ActionsEnum.applyBlueprint), logActionAudit(ActionsEnum.applyBlueprint),
blueprints.applyJSONBlueprint blueprints.applyJSONBlueprint

View File

@@ -1,13 +1,17 @@
import { db, ExitNode, newts, Transaction } from "@server/db"; import { db, ExitNode, exitNodeOrgs, newts, Transaction } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { exitNodes, Newt, sites } from "@server/db"; import { exitNodes, Newt, resources, sites, Target, targets } from "@server/db";
import { eq } from "drizzle-orm"; import { targetHealthCheck } from "@server/db";
import { eq, and, sql, inArray, ne } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers"; import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger"; import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { import {
findNextAvailableCidr, findNextAvailableCidr,
getNextAvailableClientSubnet
} from "@server/lib/ip"; } from "@server/lib/ip";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { import {
selectBestExitNode, selectBestExitNode,
verifyExitNodeOrgAccess verifyExitNodeOrgAccess
@@ -26,6 +30,8 @@ export type ExitNodePingResult = {
wasPreviouslyConnected: boolean; wasPreviouslyConnected: boolean;
}; };
const numTimesLimitExceededForId: Record<string, number> = {};
export const handleNewtRegisterMessage: MessageHandler = async (context) => { export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context; const { message, client, sendToClient } = context;
const newt = client as Newt; const newt = client as Newt;
@@ -90,6 +96,42 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
fetchContainers(newt.newtId); fetchContainers(newt.newtId);
} }
const rejectSiteUptime = await usageService.checkLimitSet(
oldSite.orgId,
false,
FeatureId.SITE_UPTIME
);
const rejectEgressDataMb = await usageService.checkLimitSet(
oldSite.orgId,
false,
FeatureId.EGRESS_DATA_MB
);
// Do we need to check the users and domains daily limits here?
// const rejectUsers = await usageService.checkLimitSet(oldSite.orgId, false, FeatureId.USERS);
// const rejectDomains = await usageService.checkLimitSet(oldSite.orgId, false, FeatureId.DOMAINS);
// if (rejectEgressDataMb || rejectSiteUptime || rejectUsers || rejectDomains) {
if (rejectEgressDataMb || rejectSiteUptime) {
logger.info(
`Usage limits exceeded for org ${oldSite.orgId}. Rejecting newt registration.`
);
// PREVENT FURTHER REGISTRATION ATTEMPTS SO WE DON'T SPAM
// Increment the limit exceeded count for this site
numTimesLimitExceededForId[newt.newtId] =
(numTimesLimitExceededForId[newt.newtId] || 0) + 1;
if (numTimesLimitExceededForId[newt.newtId] > 15) {
logger.debug(
`Newt ${newt.newtId} has exceeded usage limits 15 times. Terminating...`
);
}
return;
}
let siteSubnet = oldSite.subnet; let siteSubnet = oldSite.subnet;
let exitNodeIdToQuery = oldSite.exitNodeId; let exitNodeIdToQuery = oldSite.exitNodeId;
if (exitNodeId && (oldSite.exitNodeId !== exitNodeId || !oldSite.subnet)) { if (exitNodeId && (oldSite.exitNodeId !== exitNodeId || !oldSite.subnet)) {

View File

@@ -117,8 +117,6 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
return; return;
} }
const isUserDevice = olm.userId !== null && olm.userId !== undefined;
try { try {
// get the client // get the client
const [client] = await db const [client] = await db
@@ -221,9 +219,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
logger.error("Error handling ping message", { error }); logger.error("Error handling ping message", { error });
} }
if (isUserDevice) { await handleFingerprintInsertion(olm, fingerprint, postures);
await handleFingerprintInsertion(olm, fingerprint, postures);
}
return { return {
message: { message: {

View File

@@ -53,11 +53,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
postures postures
}); });
const isUserDevice = olm.userId !== null && olm.userId !== undefined; await handleFingerprintInsertion(olm, fingerprint, postures);
if (isUserDevice) {
await handleFingerprintInsertion(olm, fingerprint, postures);
}
if ( if (
(olmVersion && olm.version !== olmVersion) || (olmVersion && olm.version !== olmVersion) ||

View File

@@ -271,7 +271,7 @@ export async function createOrg(
// make sure we have the stripe customer // make sure we have the stripe customer
const customerId = await createCustomer(orgId, req.user?.email); const customerId = await createCustomer(orgId, req.user?.email);
if (customerId) { if (customerId) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.USERS, FeatureId.USERS,
1, 1,

View File

@@ -10,10 +10,10 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { cache } from "@server/lib/cache"; import { cache } from "@server/lib/cache";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { getOrgTierData } from "#dynamic/lib/billing";
const updateOrgParamsSchema = z.strictObject({ const updateOrgParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -88,83 +88,26 @@ export async function updateOrg(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
// Check 2FA enforcement feature const isLicensed = await isLicensedOrSubscribed(orgId);
const has2FAFeature = await isLicensedOrSubscribed( if (!isLicensed) {
orgId,
tierMatrix[TierFeature.TwoFactorEnforcement]
);
if (!has2FAFeature) {
parsedBody.data.requireTwoFactor = undefined; parsedBody.data.requireTwoFactor = undefined;
}
// Check session duration policies feature
const hasSessionDurationFeature = await isLicensedOrSubscribed(
orgId,
tierMatrix[TierFeature.SessionDurationPolicies]
);
if (!hasSessionDurationFeature) {
parsedBody.data.maxSessionLengthHours = undefined; parsedBody.data.maxSessionLengthHours = undefined;
}
// Check password expiration policies feature
const hasPasswordExpirationFeature = await isLicensedOrSubscribed(
orgId,
tierMatrix[TierFeature.PasswordExpirationPolicies]
);
if (!hasPasswordExpirationFeature) {
parsedBody.data.passwordExpiryDays = undefined; parsedBody.data.passwordExpiryDays = undefined;
} }
if (build == "saas") {
const { tier } = await getOrgTierData(orgId);
// Determine max allowed retention days based on tier const { tier } = await getOrgTierData(orgId);
let maxRetentionDays: number | null = null; if (
if (!tier) { build == "saas" &&
maxRetentionDays = 3; tier != TierId.STANDARD &&
} else if (tier === "tier1") { parsedBody.data.settingsLogRetentionDaysRequest &&
maxRetentionDays = 7; parsedBody.data.settingsLogRetentionDaysRequest > 30
} else if (tier === "tier2") { ) {
maxRetentionDays = 30; return next(
} else if (tier === "tier3") { createHttpError(
maxRetentionDays = 90; HttpCode.FORBIDDEN,
} "You are not allowed to set log retention days greater than 30 with your current subscription"
// For enterprise tier, no check (maxRetentionDays remains null) )
);
if (maxRetentionDays !== null) {
if (
parsedBody.data.settingsLogRetentionDaysRequest !== undefined &&
parsedBody.data.settingsLogRetentionDaysRequest > maxRetentionDays
) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
`You are not allowed to set log retention days greater than ${maxRetentionDays} with your current subscription`
)
);
}
if (
parsedBody.data.settingsLogRetentionDaysAccess !== undefined &&
parsedBody.data.settingsLogRetentionDaysAccess > maxRetentionDays
) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
`You are not allowed to set log retention days greater than ${maxRetentionDays} with your current subscription`
)
);
}
if (
parsedBody.data.settingsLogRetentionDaysAction !== undefined &&
parsedBody.data.settingsLogRetentionDaysAction > maxRetentionDays
) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
`You are not allowed to set log retention days greater than ${maxRetentionDays} with your current subscription`
)
);
}
}
} }
const updatedOrg = await db const updatedOrg = await db

View File

@@ -24,7 +24,6 @@ import { createCertificate } from "#dynamic/routers/certificates/createCertifica
import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { build } from "@server/build"; import { build } from "@server/build";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const updateResourceParamsSchema = z.strictObject({ const updateResourceParamsSchema = z.strictObject({
resourceId: z.string().transform(Number).pipe(z.int().positive()) resourceId: z.string().transform(Number).pipe(z.int().positive())
@@ -342,7 +341,7 @@ async function updateHttpResource(
headers = null; headers = null;
} }
const isLicensed = await isLicensedOrSubscribed(resource.orgId, tierMatrix.maintencePage); const isLicensed = await isLicensedOrSubscribed(resource.orgId);
if (!isLicensed) { if (!isLicensed) {
updateData.maintenanceModeEnabled = undefined; updateData.maintenanceModeEnabled = undefined;
updateData.maintenanceModeType = undefined; updateData.maintenanceModeType = undefined;

View File

@@ -12,7 +12,6 @@ import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build"; import { build } from "@server/build";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const createRoleParamsSchema = z.strictObject({ const createRoleParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -101,7 +100,7 @@ export async function createRole(
); );
} }
const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.deviceApprovals); const isLicensed = await isLicensedOrSubscribed(orgId);
if (!isLicensed) { if (!isLicensed) {
roleData.requireDeviceApproval = undefined; roleData.requireDeviceApproval = undefined;
} }

View File

@@ -10,7 +10,6 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const updateRoleParamsSchema = z.strictObject({ const updateRoleParamsSchema = z.strictObject({
roleId: z.string().transform(Number).pipe(z.int().positive()) roleId: z.string().transform(Number).pipe(z.int().positive())
@@ -111,7 +110,7 @@ export async function updateRole(
); );
} }
const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.deviceApprovals); const isLicensed = await isLicensedOrSubscribed(orgId);
if (!isLicensed) { if (!isLicensed) {
updateData.requireDeviceApproval = undefined; updateData.requireDeviceApproval = undefined;
} }

View File

@@ -17,9 +17,6 @@ import { hashPassword } from "@server/auth/password";
import { isValidIP } from "@server/lib/validators"; import { isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip"; import { isIpInCidr } from "@server/lib/ip";
import { verifyExitNodeOrgAccess } from "#dynamic/lib/exitNodes"; import { verifyExitNodeOrgAccess } from "#dynamic/lib/exitNodes";
import { build } from "@server/build";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
const createSiteParamsSchema = z.strictObject({ const createSiteParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -128,35 +125,6 @@ export async function createSite(
); );
} }
if (build == "saas") {
const usage = await usageService.getUsage(orgId, FeatureId.SITES);
if (!usage) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No usage data found for this organization"
)
);
}
const rejectSites = await usageService.checkLimitSet(
orgId,
FeatureId.SITES,
{
...usage,
instantaneousValue: (usage.instantaneousValue || 0) + 1
} // We need to add one to know if we are violating the limit
);
if (rejectSites) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Site limit exceeded. Please upgrade your plan."
)
);
}
}
let updatedAddress = null; let updatedAddress = null;
if (address) { if (address) {
if (!org.subnet) { if (!org.subnet) {
@@ -287,10 +255,22 @@ export async function createSite(
const niceId = await getUniqueSiteName(orgId); const niceId = await getUniqueSiteName(orgId);
let newSite: Site | undefined; let newSite: Site;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
if (type == "wireguard" || type == "newt") { if (type == "newt") {
[newSite] = await trx
.insert(sites)
.values({
orgId,
name,
niceId,
address: updatedAddress || null,
type,
dockerSocketEnabled: true
})
.returning();
} else if (type == "wireguard") {
// we are creating a site with an exit node (tunneled) // we are creating a site with an exit node (tunneled)
if (!subnet) { if (!subnet) {
return next( return next(
@@ -342,11 +322,9 @@ export async function createSite(
exitNodeId, exitNodeId,
name, name,
niceId, niceId,
address: updatedAddress || null,
subnet, subnet,
type, type,
dockerSocketEnabled: type == "newt", pubKey: pubKey || null
...(pubKey && type == "wireguard" && { pubKey })
}) })
.returning(); .returning();
} else if (type == "local") { } else if (type == "local") {
@@ -433,35 +411,13 @@ export async function createSite(
}); });
} }
numSites = await trx return response<CreateSiteResponse>(res, {
.select() data: newSite,
.from(sites) success: true,
.where(eq(sites.orgId, orgId)); error: false,
}); message: "Site created successfully",
status: HttpCode.CREATED
if (numSites) { });
await usageService.updateCount(
orgId,
FeatureId.SITES,
numSites.length
);
}
if (!newSite) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create site"
)
);
}
return response<CreateSiteResponse>(res, {
data: newSite,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
}); });
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, Site, siteResources } from "@server/db"; import { db, siteResources } from "@server/db";
import { newts, newtSessions, sites } from "@server/db"; import { newts, newtSessions, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -12,8 +12,6 @@ import { fromError } from "zod-validation-error";
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
const deleteSiteSchema = z.strictObject({ const deleteSiteSchema = z.strictObject({
siteId: z.string().transform(Number).pipe(z.int().positive()) siteId: z.string().transform(Number).pipe(z.int().positive())
@@ -64,7 +62,6 @@ export async function deleteSite(
} }
let deletedNewtId: string | null = null; let deletedNewtId: string | null = null;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
if (site.type == "wireguard") { if (site.type == "wireguard") {
@@ -102,20 +99,8 @@ export async function deleteSite(
} }
await trx.delete(sites).where(eq(sites.siteId, siteId)); await trx.delete(sites).where(eq(sites.siteId, siteId));
numSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
}); });
if (numSites) {
await usageService.updateCount(
site.orgId,
FeatureId.SITES,
numSites.length
);
}
// Send termination message outside of transaction to prevent blocking // Send termination message outside of transaction to prevent blocking
if (deletedNewtId) { if (deletedNewtId) {
const payload = { const payload = {

View File

@@ -13,7 +13,6 @@ import { verifySession } from "@server/auth/sessions/verifySession";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing"; import { FeatureId } from "@server/lib/billing";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { build } from "@server/build";
const acceptInviteBodySchema = z.strictObject({ const acceptInviteBodySchema = z.strictObject({
token: z.string(), token: z.string(),
@@ -93,38 +92,6 @@ export async function acceptInvite(
); );
} }
if (build == "saas") {
const usage = await usageService.getUsage(
existingInvite.orgId,
FeatureId.USERS
);
if (!usage) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No usage data found for this organization"
)
);
}
const rejectUsers = await usageService.checkLimitSet(
existingInvite.orgId,
FeatureId.USERS,
{
...usage,
instantaneousValue: (usage.instantaneousValue || 0) + 1
} // We need to add one to know if we are violating the limit
);
if (rejectUsers) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Can not accept because this org's user limit is exceeded. Please contact your administrator to upgrade their plan."
)
);
}
}
let roleId: number; let roleId: number;
let totalUsers: UserOrg[] | undefined; let totalUsers: UserOrg[] | undefined;
// get the role to make sure it exists // get the role to make sure it exists
@@ -158,21 +125,17 @@ export async function acceptInvite(
.delete(userInvites) .delete(userInvites)
.where(eq(userInvites.inviteId, inviteId)); .where(eq(userInvites.inviteId, inviteId));
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
// Get the total number of users in the org now // Get the total number of users in the org now
totalUsers = await trx totalUsers = await db
.select() .select()
.from(userOrgs) .from(userOrgs)
.where(eq(userOrgs.orgId, existingInvite.orgId)); .where(eq(userOrgs.orgId, existingInvite.orgId));
logger.debug( await calculateUserClientsForOrgs(existingUser[0].userId, trx);
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}`
);
}); });
if (totalUsers) { if (totalUsers) {
await usageService.updateCount( await usageService.updateDaily(
existingInvite.orgId, existingInvite.orgId,
FeatureId.USERS, FeatureId.USERS,
totalUsers.length totalUsers.length

View File

@@ -13,16 +13,20 @@ import { generateId } from "@server/auth/sessions/app";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing"; import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string().nonempty() orgId: z.string().nonempty()
}); });
const bodySchema = z.strictObject({ const bodySchema = z.strictObject({
email: z.string().email().toLowerCase().optional(), email: z
.string()
.email()
.toLowerCase()
.optional(),
username: z.string().nonempty().toLowerCase(), username: z.string().nonempty().toLowerCase(),
name: z.string().optional(), name: z.string().optional(),
type: z.enum(["internal", "oidc"]).optional(), type: z.enum(["internal", "oidc"]).optional(),
@@ -91,7 +95,7 @@ export async function createOrgUser(
} }
const rejectUsers = await usageService.checkLimitSet( const rejectUsers = await usageService.checkLimitSet(
orgId, orgId,
false,
FeatureId.USERS, FeatureId.USERS,
{ {
...usage, ...usage,
@@ -128,11 +132,9 @@ export async function createOrgUser(
); );
} else if (type === "oidc") { } else if (type === "oidc") {
if (build === "saas") { if (build === "saas") {
const subscribed = await isSubscribed( const { tier } = await getOrgTierData(orgId);
orgId, const subscribed = tier === TierId.STANDARD;
tierMatrix.orgOidc if (!subscribed) {
);
if (subscribed) {
return next( return next(
createHttpError( createHttpError(
HttpCode.FORBIDDEN, HttpCode.FORBIDDEN,
@@ -254,7 +256,7 @@ export async function createOrgUser(
}); });
if (orgUsers) { if (orgUsers) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.USERS, FeatureId.USERS,
orgUsers.length orgUsers.length

View File

@@ -133,6 +133,7 @@ export async function inviteUser(
} }
const rejectUsers = await usageService.checkLimitSet( const rejectUsers = await usageService.checkLimitSet(
orgId, orgId,
false,
FeatureId.USERS, FeatureId.USERS,
{ {
...usage, ...usage,

View File

@@ -140,7 +140,7 @@ export async function removeUserOrg(
}); });
if (userCount) { if (userCount) {
await usageService.updateCount( await usageService.updateDaily(
orgId, orgId,
FeatureId.USERS, FeatureId.USERS,
userCount.length userCount.length

View File

@@ -1,162 +0,0 @@
#! /usr/bin/env node
import { migrate } from "drizzle-orm/node-postgres/migrator";
import { db } from "../db/pg";
import semver from "semver";
import { versionMigrations } from "../db/pg";
import { __DIRNAME, APP_VERSION } from "@server/lib/consts";
import path from "path";
import m1 from "./scriptsPg/1.6.0";
import m2 from "./scriptsPg/1.7.0";
import m3 from "./scriptsPg/1.8.0";
import m4 from "./scriptsPg/1.9.0";
import m5 from "./scriptsPg/1.10.0";
import m6 from "./scriptsPg/1.10.2";
import m7 from "./scriptsPg/1.11.0";
import m8 from "./scriptsPg/1.11.1";
import m9 from "./scriptsPg/1.12.0";
import m10 from "./scriptsPg/1.13.0";
import m11 from "./scriptsPg/1.14.0";
import m12 from "./scriptsPg/1.15.0";
// THIS CANNOT IMPORT ANYTHING FROM THE SERVER
// EXCEPT FOR THE DATABASE AND THE SCHEMA
// Define the migration list with versions and their corresponding functions
const migrations = [
{ version: "1.6.0", run: m1 },
{ version: "1.7.0", run: m2 },
{ version: "1.8.0", run: m3 },
{ version: "1.9.0", run: m4 },
{ version: "1.10.0", run: m5 },
{ version: "1.10.2", run: m6 },
{ version: "1.11.0", run: m7 },
{ version: "1.11.1", run: m8 },
{ version: "1.12.0", run: m9 },
{ version: "1.13.0", run: m10 },
{ version: "1.14.0", run: m11 },
{ version: "1.15.0", run: m12 }
// Add new migrations here as they are created
] as {
version: string;
run: () => Promise<void>;
}[];
await run();
async function run() {
// run the migrations
await runMigrations();
}
export async function runMigrations() {
if (process.env.DISABLE_MIGRATIONS) {
console.log("Migrations are disabled. Skipping...");
return;
}
try {
const appVersion = APP_VERSION;
// determine if the migrations table exists
const exists = await db
.select()
.from(versionMigrations)
.limit(1)
.execute()
.then((res) => res.length > 0)
.catch(() => false);
if (exists) {
console.log("Migrations table exists, running scripts...");
await executeScripts();
} else {
console.log("Migrations table does not exist, creating it...");
console.log("Running migrations...");
try {
await migrate(db, {
migrationsFolder: path.join(__DIRNAME, "init") // put here during the docker build
});
console.log("Migrations completed successfully.");
} catch (error) {
console.error("Error running migrations:", error);
}
await db
.insert(versionMigrations)
.values({
version: appVersion,
executedAt: Date.now()
})
.execute();
}
} catch (e) {
console.error("Error running migrations:", e);
await new Promise((resolve) =>
setTimeout(resolve, 1000 * 60 * 60 * 24 * 1)
);
}
}
async function executeScripts() {
try {
// Get the last executed version from the database
const lastExecuted = await db.select().from(versionMigrations);
// Filter and sort migrations
const pendingMigrations = lastExecuted
.map((m) => m)
.sort((a, b) => semver.compare(b.version, a.version));
const startVersion = pendingMigrations[0]?.version ?? "0.0.0";
console.log(`Starting migrations from version ${startVersion}`);
const migrationsToRun = migrations.filter((migration) =>
semver.gt(migration.version, startVersion)
);
console.log(
"Migrations to run:",
migrationsToRun.map((m) => m.version).join(", ")
);
// Run migrations in order
for (const migration of migrationsToRun) {
console.log(`Running migration ${migration.version}`);
try {
await migration.run();
// Update version in database
await db
.insert(versionMigrations)
.values({
version: migration.version,
executedAt: Date.now()
})
.execute();
console.log(
`Successfully completed migration ${migration.version}`
);
} catch (e) {
if (
e instanceof Error &&
typeof (e as any).code === "string" &&
(e as any).code === "23505"
) {
console.error("Migration has already run! Skipping...");
continue; // or return, depending on context
}
console.error(
`Failed to run migration ${migration.version}:`,
e
);
throw e;
}
}
console.log("All migrations completed successfully");
} catch (error) {
console.error("Migration process failed:", error);
throw error;
}
}

View File

@@ -1 +0,0 @@
export type Tier = "tier1" | "tier2" | "tier3" | "enterprise";

View File

@@ -19,7 +19,6 @@ import OrgPolicyResult from "@app/components/OrgPolicyResult";
import UserProvider from "@app/providers/UserProvider"; import UserProvider from "@app/providers/UserProvider";
import { Layout } from "@app/components/Layout"; import { Layout } from "@app/components/Layout";
import ApplyInternalRedirect from "@app/components/ApplyInternalRedirect"; import ApplyInternalRedirect from "@app/components/ApplyInternalRedirect";
import SubscriptionViolation from "@app/components/SubscriptionViolation";
export default async function OrgLayout(props: { export default async function OrgLayout(props: {
children: React.ReactNode; children: React.ReactNode;
@@ -109,7 +108,6 @@ export default async function OrgLayout(props: {
> >
<ApplyInternalRedirect orgId={orgId} /> <ApplyInternalRedirect orgId={orgId} />
{props.children} {props.children}
{build === "saas" && <SubscriptionViolation />}
<SetLastOrgCookie orgId={orgId} /> <SetLastOrgCookie orgId={orgId} />
</SubscriptionStatusProvider> </SubscriptionStatusProvider>
); );

View File

@@ -11,7 +11,6 @@ import type { GetOrgResponse } from "@server/routers/org";
import type { ListRolesResponse } from "@server/routers/role"; import type { ListRolesResponse } from "@server/routers/role";
import type { AxiosResponse } from "axios"; import type { AxiosResponse } from "axios";
import { getTranslations } from "next-intl/server"; import { getTranslations } from "next-intl/server";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export interface ApprovalFeedPageProps { export interface ApprovalFeedPageProps {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
@@ -30,9 +29,10 @@ export default async function ApprovalFeedPage(props: ApprovalFeedPageProps) {
// Fetch roles to check if approvals are enabled // Fetch roles to check if approvals are enabled
let hasApprovalsEnabled = false; let hasApprovalsEnabled = false;
const rolesRes = await internal const rolesRes = await internal
.get< .get<AxiosResponse<ListRolesResponse>>(
AxiosResponse<ListRolesResponse> `/org/${params.orgId}/roles`,
>(`/org/${params.orgId}/roles`, await authCookieHeader()) await authCookieHeader()
)
.catch((e) => {}); .catch((e) => {});
if (rolesRes && rolesRes.status === 200) { if (rolesRes && rolesRes.status === 200) {
@@ -52,7 +52,7 @@ export default async function ApprovalFeedPage(props: ApprovalFeedPageProps) {
<ApprovalsBanner /> <ApprovalsBanner />
<PaidFeaturesAlert tiers={tierMatrix.deviceApprovals} /> <PaidFeaturesAlert />
<OrgProvider org={org}> <OrgProvider org={org}>
<div className="container mx-auto max-w-12xl"> <div className="container mx-auto max-w-12xl">

File diff suppressed because it is too large Load Diff

View File

@@ -31,6 +31,7 @@ import { formatAxiosError } from "@app/lib/api";
import { createApiClient } from "@app/lib/api"; import { createApiClient } from "@app/lib/api";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { useState, useEffect } from "react"; import { useState, useEffect } from "react";
import { SwitchInput } from "@app/components/SwitchInput";
import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert"; import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert";
import { InfoIcon, ExternalLink } from "lucide-react"; import { InfoIcon, ExternalLink } from "lucide-react";
import { import {
@@ -40,13 +41,12 @@ import {
InfoSectionTitle InfoSectionTitle
} from "@app/components/InfoSection"; } from "@app/components/InfoSection";
import CopyToClipboard from "@app/components/CopyToClipboard"; import CopyToClipboard from "@app/components/CopyToClipboard";
import { useLicenseStatusContext } from "@app/hooks/useLicenseStatusContext";
import IdpTypeBadge from "@app/components/IdpTypeBadge"; import IdpTypeBadge from "@app/components/IdpTypeBadge";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import { ListRolesResponse } from "@server/routers/role"; import { ListRolesResponse } from "@server/routers/role";
import AutoProvisionConfigWidget from "@app/components/AutoProvisionConfigWidget"; import AutoProvisionConfigWidget from "@app/components/private/AutoProvisionConfigWidget";
import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export default function GeneralPage() { export default function GeneralPage() {
const { env } = useEnvContext(); const { env } = useEnvContext();
@@ -60,6 +60,7 @@ export default function GeneralPage() {
"role" | "expression" "role" | "expression"
>("role"); >("role");
const [variant, setVariant] = useState<"oidc" | "google" | "azure">("oidc"); const [variant, setVariant] = useState<"oidc" | "google" | "azure">("oidc");
const { isUnlocked } = useLicenseStatusContext();
const dashboardRedirectUrl = `${env.app.dashboardUrl}/auth/idp/${idpId}/oidc/callback`; const dashboardRedirectUrl = `${env.app.dashboardUrl}/auth/idp/${idpId}/oidc/callback`;
const [redirectUrl, setRedirectUrl] = useState( const [redirectUrl, setRedirectUrl] = useState(
@@ -498,10 +499,6 @@ export default function GeneralPage() {
</SettingsSectionHeader> </SettingsSectionHeader>
<SettingsSectionBody> <SettingsSectionBody>
<SettingsSectionForm> <SettingsSectionForm>
<PaidFeaturesAlert
tiers={tierMatrix.autoProvisioning}
/>
<Form {...form}> <Form {...form}>
<form <form
onSubmit={form.handleSubmit(onSubmit)} onSubmit={form.handleSubmit(onSubmit)}

View File

@@ -1,7 +1,6 @@
"use client"; "use client";
import AutoProvisionConfigWidget from "@app/components/AutoProvisionConfigWidget"; import AutoProvisionConfigWidget from "@app/components/private/AutoProvisionConfigWidget";
import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert";
import { import {
SettingsContainer, SettingsContainer,
SettingsSection, SettingsSection,
@@ -28,11 +27,9 @@ import {
import { Input } from "@app/components/ui/input"; import { Input } from "@app/components/ui/input";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { useLicenseStatusContext } from "@app/hooks/useLicenseStatusContext"; import { useLicenseStatusContext } from "@app/hooks/useLicenseStatusContext";
import { usePaidStatus } from "@app/hooks/usePaidStatus";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { createApiClient, formatAxiosError } from "@app/lib/api"; import { createApiClient, formatAxiosError } from "@app/lib/api";
import { zodResolver } from "@hookform/resolvers/zod"; import { zodResolver } from "@hookform/resolvers/zod";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { ListRolesResponse } from "@server/routers/role"; import { ListRolesResponse } from "@server/routers/role";
import { AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import { InfoIcon } from "lucide-react"; import { InfoIcon } from "lucide-react";
@@ -52,8 +49,8 @@ export default function Page() {
const [roleMappingMode, setRoleMappingMode] = useState< const [roleMappingMode, setRoleMappingMode] = useState<
"role" | "expression" "role" | "expression"
>("role"); >("role");
const { isUnlocked } = useLicenseStatusContext();
const t = useTranslations(); const t = useTranslations();
const { isPaidUser } = usePaidStatus();
const params = useParams(); const params = useParams();
@@ -364,9 +361,6 @@ export default function Page() {
</SettingsSectionHeader> </SettingsSectionHeader>
<SettingsSectionBody> <SettingsSectionBody>
<SettingsSectionForm> <SettingsSectionForm>
<PaidFeaturesAlert
tiers={tierMatrix.autoProvisioning}
/>
<Form {...form}> <Form {...form}>
<form <form
className="space-y-4" className="space-y-4"
@@ -812,7 +806,7 @@ export default function Page() {
</Button> </Button>
<Button <Button
type="submit" type="submit"
disabled={createLoading || !isPaidUser(tierMatrix.orgOidc)} disabled={createLoading}
loading={createLoading} loading={createLoading}
onClick={() => { onClick={() => {
// log any issues with the form // log any issues with the form

View File

@@ -1,8 +1,18 @@
import { pullEnv } from "@app/lib/pullEnv";
import { build } from "@server/build";
import { redirect } from "next/navigation";
interface LayoutProps { interface LayoutProps {
children: React.ReactNode; children: React.ReactNode;
params: Promise<{}>; params: Promise<{}>;
} }
export default async function Layout(props: LayoutProps) { export default async function Layout(props: LayoutProps) {
const env = pullEnv();
if (build !== "saas" && !env.flags.useOrgOnlyIdp) {
redirect("/");
}
return props.children; return props.children;
} }

View File

@@ -2,10 +2,9 @@ import { internal } from "@app/lib/api";
import { authCookieHeader } from "@app/lib/api/cookies"; import { authCookieHeader } from "@app/lib/api/cookies";
import { AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import SettingsSectionTitle from "@app/components/SettingsSectionTitle"; import SettingsSectionTitle from "@app/components/SettingsSectionTitle";
import IdpTable, { IdpRow } from "@app/components/OrgIdpTable"; import IdpTable, { IdpRow } from "@app/components/private/OrgIdpTable";
import { getTranslations } from "next-intl/server"; import { getTranslations } from "next-intl/server";
import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert"; import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
type OrgIdpPageProps = { type OrgIdpPageProps = {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
@@ -36,7 +35,7 @@ export default async function OrgIdpPage(props: OrgIdpPageProps) {
description={t("idpManageDescription")} description={t("idpManageDescription")}
/> />
<PaidFeaturesAlert tiers={tierMatrix.orgOidc} /> <PaidFeaturesAlert />
<IdpTable idps={idps} orgId={params.orgId} /> <IdpTable idps={idps} orgId={params.orgId} />
</> </>

View File

@@ -23,6 +23,9 @@ import {
} from "@server/routers/remoteExitNode/types"; } from "@server/routers/remoteExitNode/types";
import { useRemoteExitNodeContext } from "@app/hooks/useRemoteExitNodeContext"; import { useRemoteExitNodeContext } from "@app/hooks/useRemoteExitNodeContext";
import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog"; import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog";
import { useSubscriptionStatusContext } from "@app/hooks/useSubscriptionStatusContext";
import { useLicenseStatusContext } from "@app/hooks/useLicenseStatusContext";
import { build } from "@server/build";
import { import {
InfoSection, InfoSection,
InfoSectionContent, InfoSectionContent,
@@ -33,8 +36,6 @@ import CopyToClipboard from "@app/components/CopyToClipboard";
import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert"; import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert";
import { InfoIcon } from "lucide-react"; import { InfoIcon } from "lucide-react";
import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert"; import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert";
import { usePaidStatus } from "@app/hooks/usePaidStatus";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export default function CredentialsPage() { export default function CredentialsPage() {
const { env } = useEnvContext(); const { env } = useEnvContext();
@@ -44,8 +45,6 @@ export default function CredentialsPage() {
const t = useTranslations(); const t = useTranslations();
const { remoteExitNode } = useRemoteExitNodeContext(); const { remoteExitNode } = useRemoteExitNodeContext();
const { isPaidUser } = usePaidStatus();
const [modalOpen, setModalOpen] = useState(false); const [modalOpen, setModalOpen] = useState(false);
const [credentials, setCredentials] = const [credentials, setCredentials] =
useState<PickRemoteExitNodeDefaultsResponse | null>(null); useState<PickRemoteExitNodeDefaultsResponse | null>(null);
@@ -58,6 +57,16 @@ export default function CredentialsPage() {
const [showCredentialsAlert, setShowCredentialsAlert] = useState(false); const [showCredentialsAlert, setShowCredentialsAlert] = useState(false);
const [shouldDisconnect, setShouldDisconnect] = useState(true); const [shouldDisconnect, setShouldDisconnect] = useState(true);
const { licenseStatus, isUnlocked } = useLicenseStatusContext();
const subscription = useSubscriptionStatusContext();
const isSecurityFeatureDisabled = () => {
const isEnterpriseNotLicensed = build === "enterprise" && !isUnlocked();
const isSaasNotSubscribed =
build === "saas" && !subscription?.isSubscribed();
return isEnterpriseNotLicensed || isSaasNotSubscribed;
};
const handleConfirmRegenerate = async () => { const handleConfirmRegenerate = async () => {
try { try {
const response = await api.get< const response = await api.get<
@@ -129,9 +138,7 @@ export default function CredentialsPage() {
</SettingsSectionDescription> </SettingsSectionDescription>
</SettingsSectionHeader> </SettingsSectionHeader>
<SettingsSectionBody> <SettingsSectionBody>
<PaidFeaturesAlert <PaidFeaturesAlert />
tiers={tierMatrix.rotateCredentials}
/>
<InfoSections cols={3}> <InfoSections cols={3}>
<InfoSection> <InfoSection>
@@ -188,7 +195,7 @@ export default function CredentialsPage() {
</Alert> </Alert>
)} )}
</SettingsSectionBody> </SettingsSectionBody>
{!env.flags.disableEnterpriseFeatures && ( {build !== "oss" && (
<SettingsSectionFooter> <SettingsSectionFooter>
<Button <Button
variant="outline" variant="outline"
@@ -196,9 +203,7 @@ export default function CredentialsPage() {
setShouldDisconnect(false); setShouldDisconnect(false);
setModalOpen(true); setModalOpen(true);
}} }}
disabled={ disabled={isSecurityFeatureDisabled()}
!isPaidUser(tierMatrix.rotateCredentials)
}
> >
{t("regenerateCredentialsButton")} {t("regenerateCredentialsButton")}
</Button> </Button>
@@ -207,9 +212,7 @@ export default function CredentialsPage() {
setShouldDisconnect(true); setShouldDisconnect(true);
setModalOpen(true); setModalOpen(true);
}} }}
disabled={ disabled={isSecurityFeatureDisabled()}
!isPaidUser(tierMatrix.rotateCredentials)
}
> >
{t("remoteExitNodeRegenerateAndDisconnect")} {t("remoteExitNodeRegenerateAndDisconnect")}
</Button> </Button>

View File

@@ -47,8 +47,8 @@ import { ListIdpsResponse } from "@server/routers/idp";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { build } from "@server/build"; import { build } from "@server/build";
import Image from "next/image"; import Image from "next/image";
import { usePaidStatus } from "@app/hooks/usePaidStatus"; import { useSubscriptionStatusContext } from "@app/hooks/useSubscriptionStatusContext";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { TierId } from "@server/lib/billing/tiers";
type UserType = "internal" | "oidc"; type UserType = "internal" | "oidc";
@@ -76,7 +76,7 @@ export default function Page() {
const api = createApiClient({ env }); const api = createApiClient({ env });
const t = useTranslations(); const t = useTranslations();
const { hasSaasSubscription } = usePaidStatus(); const subscription = useSubscriptionStatusContext();
const [selectedOption, setSelectedOption] = useState<string | null>( const [selectedOption, setSelectedOption] = useState<string | null>(
"internal" "internal"
@@ -238,7 +238,7 @@ export default function Page() {
} }
async function fetchIdps() { async function fetchIdps() {
if (build === "saas" && !hasSaasSubscription(tierMatrix.orgOidc)) { if (build === "saas" && !subscription?.subscribed) {
return; return;
} }

View File

@@ -19,6 +19,9 @@ import { useTranslations } from "next-intl";
import { PickClientDefaultsResponse } from "@server/routers/client"; import { PickClientDefaultsResponse } from "@server/routers/client";
import { useClientContext } from "@app/hooks/useClientContext"; import { useClientContext } from "@app/hooks/useClientContext";
import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog"; import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog";
import { useLicenseStatusContext } from "@app/hooks/useLicenseStatusContext";
import { useSubscriptionStatusContext } from "@app/hooks/useSubscriptionStatusContext";
import { build } from "@server/build";
import { import {
InfoSection, InfoSection,
InfoSectionContent, InfoSectionContent,
@@ -30,8 +33,6 @@ import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert";
import { InfoIcon } from "lucide-react"; import { InfoIcon } from "lucide-react";
import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert"; import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert";
import { OlmInstallCommands } from "@app/components/olm-install-commands"; import { OlmInstallCommands } from "@app/components/olm-install-commands";
import { usePaidStatus } from "@app/hooks/usePaidStatus";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export default function CredentialsPage() { export default function CredentialsPage() {
const { env } = useEnvContext(); const { env } = useEnvContext();
@@ -53,7 +54,15 @@ export default function CredentialsPage() {
const [showCredentialsAlert, setShowCredentialsAlert] = useState(false); const [showCredentialsAlert, setShowCredentialsAlert] = useState(false);
const [shouldDisconnect, setShouldDisconnect] = useState(true); const [shouldDisconnect, setShouldDisconnect] = useState(true);
const { isPaidUser } = usePaidStatus(); const { licenseStatus, isUnlocked } = useLicenseStatusContext();
const subscription = useSubscriptionStatusContext();
const isSecurityFeatureDisabled = () => {
const isEnterpriseNotLicensed = build === "enterprise" && !isUnlocked();
const isSaasNotSubscribed =
build === "saas" && !subscription?.isSubscribed();
return isEnterpriseNotLicensed || isSaasNotSubscribed;
};
const handleConfirmRegenerate = async () => { const handleConfirmRegenerate = async () => {
try { try {
@@ -119,9 +128,7 @@ export default function CredentialsPage() {
</SettingsSectionDescription> </SettingsSectionDescription>
</SettingsSectionHeader> </SettingsSectionHeader>
<SettingsSectionBody> <SettingsSectionBody>
<PaidFeaturesAlert <PaidFeaturesAlert />
tiers={tierMatrix.rotateCredentials}
/>
<InfoSections cols={3}> <InfoSections cols={3}>
<InfoSection> <InfoSection>
@@ -174,7 +181,7 @@ export default function CredentialsPage() {
</Alert> </Alert>
)} )}
</SettingsSectionBody> </SettingsSectionBody>
{!env.flags.disableEnterpriseFeatures && ( {build !== "oss" && (
<SettingsSectionFooter> <SettingsSectionFooter>
<Button <Button
variant="outline" variant="outline"
@@ -182,9 +189,7 @@ export default function CredentialsPage() {
setShouldDisconnect(false); setShouldDisconnect(false);
setModalOpen(true); setModalOpen(true);
}} }}
disabled={ disabled={isSecurityFeatureDisabled()}
!isPaidUser(tierMatrix.rotateCredentials)
}
> >
{t("regenerateCredentialsButton")} {t("regenerateCredentialsButton")}
</Button> </Button>
@@ -193,9 +198,7 @@ export default function CredentialsPage() {
setShouldDisconnect(true); setShouldDisconnect(true);
setModalOpen(true); setModalOpen(true);
}} }}
disabled={ disabled={isSecurityFeatureDisabled()}
!isPaidUser(tierMatrix.rotateCredentials)
}
> >
{t("clientRegenerateAndDisconnect")} {t("clientRegenerateAndDisconnect")}
</Button> </Button>

View File

@@ -28,19 +28,10 @@ import { createApiClient, formatAxiosError } from "@app/lib/api";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useState, useEffect, useTransition } from "react"; import { useState, useEffect, useTransition } from "react";
import { import { Check, Ban, Shield, ShieldOff, Clock, CheckCircle2, XCircle } from "lucide-react";
Check,
Ban,
Shield,
ShieldOff,
Clock,
CheckCircle2,
XCircle
} from "lucide-react";
import { useParams } from "next/navigation"; import { useParams } from "next/navigation";
import { FaApple, FaWindows, FaLinux } from "react-icons/fa"; import { FaApple, FaWindows, FaLinux } from "react-icons/fa";
import { SiAndroid } from "react-icons/si"; import { SiAndroid } from "react-icons/si";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
function formatTimestamp(timestamp: number | null | undefined): string { function formatTimestamp(timestamp: number | null | undefined): string {
if (!timestamp) return "-"; if (!timestamp) return "-";
@@ -120,13 +111,13 @@ function getPlatformFieldConfig(
osVersion: { show: true, labelKey: "iosVersion" }, osVersion: { show: true, labelKey: "iosVersion" },
kernelVersion: { show: false, labelKey: "kernelVersion" }, kernelVersion: { show: false, labelKey: "kernelVersion" },
arch: { show: true, labelKey: "architecture" }, arch: { show: true, labelKey: "architecture" },
deviceModel: { show: true, labelKey: "deviceModel" } deviceModel: { show: true, labelKey: "deviceModel" },
}, },
android: { android: {
osVersion: { show: true, labelKey: "androidVersion" }, osVersion: { show: true, labelKey: "androidVersion" },
kernelVersion: { show: true, labelKey: "kernelVersion" }, kernelVersion: { show: true, labelKey: "kernelVersion" },
arch: { show: true, labelKey: "architecture" }, arch: { show: true, labelKey: "architecture" },
deviceModel: { show: true, labelKey: "deviceModel" } deviceModel: { show: true, labelKey: "deviceModel" },
}, },
unknown: { unknown: {
osVersion: { show: true, labelKey: "osVersion" }, osVersion: { show: true, labelKey: "osVersion" },
@@ -142,6 +133,7 @@ function getPlatformFieldConfig(
return configs[normalizedPlatform] || configs.unknown; return configs[normalizedPlatform] || configs.unknown;
} }
export default function GeneralPage() { export default function GeneralPage() {
const { client, updateClient } = useClientContext(); const { client, updateClient } = useClientContext();
const { isPaidUser } = usePaidStatus(); const { isPaidUser } = usePaidStatus();
@@ -153,13 +145,11 @@ export default function GeneralPage() {
const [approvalId, setApprovalId] = useState<number | null>(null); const [approvalId, setApprovalId] = useState<number | null>(null);
const [isRefreshing, setIsRefreshing] = useState(false); const [isRefreshing, setIsRefreshing] = useState(false);
const [, startTransition] = useTransition(); const [, startTransition] = useTransition();
const { env } = useEnvContext();
const showApprovalFeatures = const showApprovalFeatures = build !== "oss" && isPaidUser;
build !== "oss" && isPaidUser(tierMatrix.deviceApprovals);
const formatPostureValue = (value: boolean | null | undefined | "-") => { const formatPostureValue = (value: boolean | null | undefined) => {
if (value === null || value === undefined || value === "-") return "-"; if (value === null || value === undefined) return "-";
return ( return (
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
{value ? ( {value ? (
@@ -433,8 +423,7 @@ export default function GeneralPage() {
{t( {t(
fieldConfig fieldConfig
.osVersion .osVersion
?.labelKey || ?.labelKey || "osVersion"
"osVersion"
)} )}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
@@ -570,7 +559,8 @@ export default function GeneralPage() {
</SettingsSection> </SettingsSection>
)} )}
{!env.flags.disableEnterpriseFeatures && ( {/* Device Security Section */}
{build !== "oss" && (
<SettingsSection> <SettingsSection>
<SettingsSectionHeader> <SettingsSectionHeader>
<SettingsSectionTitle> <SettingsSectionTitle>
@@ -582,27 +572,20 @@ export default function GeneralPage() {
</SettingsSectionHeader> </SettingsSectionHeader>
<SettingsSectionBody> <SettingsSectionBody>
<PaidFeaturesAlert tiers={tierMatrix.devicePosture} /> {client.posture && Object.keys(client.posture).length > 0 ? (
{client.posture &&
Object.keys(client.posture).length > 0 ? (
<> <>
{!isPaidUser && <PaidFeaturesAlert />}
<InfoSections cols={3}> <InfoSections cols={3}>
{client.posture.biometricsEnabled !== {client.posture.biometricsEnabled !== null &&
null && client.posture.biometricsEnabled !== undefined && (
client.posture.biometricsEnabled !==
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t("biometricsEnabled")} {t("biometricsEnabled")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture.biometricsEnabled
.biometricsEnabled
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -610,19 +593,15 @@ export default function GeneralPage() {
)} )}
{client.posture.diskEncrypted !== null && {client.posture.diskEncrypted !== null &&
client.posture.diskEncrypted !== client.posture.diskEncrypted !== undefined && (
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t("diskEncrypted")} {t("diskEncrypted")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture.diskEncrypted
.diskEncrypted
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -630,40 +609,31 @@ export default function GeneralPage() {
)} )}
{client.posture.firewallEnabled !== null && {client.posture.firewallEnabled !== null &&
client.posture.firewallEnabled !== client.posture.firewallEnabled !== undefined && (
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t("firewallEnabled")} {t("firewallEnabled")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture.firewallEnabled
.firewallEnabled
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
</InfoSection> </InfoSection>
)} )}
{client.posture.autoUpdatesEnabled !== {client.posture.autoUpdatesEnabled !== null &&
null && client.posture.autoUpdatesEnabled !== undefined && (
client.posture.autoUpdatesEnabled !==
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t("autoUpdatesEnabled")} {t("autoUpdatesEnabled")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture.autoUpdatesEnabled
.autoUpdatesEnabled
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -671,40 +641,29 @@ export default function GeneralPage() {
)} )}
{client.posture.tpmAvailable !== null && {client.posture.tpmAvailable !== null &&
client.posture.tpmAvailable !== client.posture.tpmAvailable !== undefined && (
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t("tpmAvailable")} {t("tpmAvailable")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture.tpmAvailable
.tpmAvailable
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
</InfoSection> </InfoSection>
)} )}
{client.posture.windowsAntivirusEnabled !== {client.posture.windowsAntivirusEnabled !== null &&
null && client.posture.windowsAntivirusEnabled !== undefined && (
client.posture
.windowsAntivirusEnabled !==
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t( {t("windowsAntivirusEnabled")}
"windowsAntivirusEnabled"
)}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.windowsAntivirusEnabled .windowsAntivirusEnabled
@@ -715,40 +674,30 @@ export default function GeneralPage() {
)} )}
{client.posture.macosSipEnabled !== null && {client.posture.macosSipEnabled !== null &&
client.posture.macosSipEnabled !== client.posture.macosSipEnabled !== undefined && (
undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t("macosSipEnabled")} {t("macosSipEnabled")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture.macosSipEnabled
.macosSipEnabled
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
</InfoSection> </InfoSection>
)} )}
{client.posture.macosGatekeeperEnabled !== {client.posture.macosGatekeeperEnabled !== null &&
null && client.posture.macosGatekeeperEnabled !==
client.posture
.macosGatekeeperEnabled !==
undefined && ( undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t( {t("macosGatekeeperEnabled")}
"macosGatekeeperEnabled"
)}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.macosGatekeeperEnabled .macosGatekeeperEnabled
@@ -758,21 +707,15 @@ export default function GeneralPage() {
</InfoSection> </InfoSection>
)} )}
{client.posture.macosFirewallStealthMode !== {client.posture.macosFirewallStealthMode !== null &&
null && client.posture.macosFirewallStealthMode !==
client.posture
.macosFirewallStealthMode !==
undefined && ( undefined && (
<InfoSection> <InfoSection>
<InfoSectionTitle> <InfoSectionTitle>
{t( {t("macosFirewallStealthMode")}
"macosFirewallStealthMode"
)}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.macosFirewallStealthMode .macosFirewallStealthMode
@@ -782,8 +725,7 @@ export default function GeneralPage() {
</InfoSection> </InfoSection>
)} )}
{client.posture.linuxAppArmorEnabled !== {client.posture.linuxAppArmorEnabled !== null &&
null &&
client.posture.linuxAppArmorEnabled !== client.posture.linuxAppArmorEnabled !==
undefined && ( undefined && (
<InfoSection> <InfoSection>
@@ -791,9 +733,7 @@ export default function GeneralPage() {
{t("linuxAppArmorEnabled")} {t("linuxAppArmorEnabled")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.linuxAppArmorEnabled .linuxAppArmorEnabled
@@ -803,8 +743,7 @@ export default function GeneralPage() {
</InfoSection> </InfoSection>
)} )}
{client.posture.linuxSELinuxEnabled !== {client.posture.linuxSELinuxEnabled !== null &&
null &&
client.posture.linuxSELinuxEnabled !== client.posture.linuxSELinuxEnabled !==
undefined && ( undefined && (
<InfoSection> <InfoSection>
@@ -812,9 +751,7 @@ export default function GeneralPage() {
{t("linuxSELinuxEnabled")} {t("linuxSELinuxEnabled")}
</InfoSectionTitle> </InfoSectionTitle>
<InfoSectionContent> <InfoSectionContent>
{isPaidUser( {isPaidUser
tierMatrix.devicePosture
)
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.linuxSELinuxEnabled .linuxSELinuxEnabled

Some files were not shown because too many files have changed in this diff Show More