Merge branch 'dev'

This commit is contained in:
Owen
2025-10-15 12:13:15 -07:00
3 changed files with 219 additions and 198 deletions

View File

@@ -13,72 +13,163 @@
import config from "./config"; import config from "./config";
import { certificates, db } from "@server/db"; import { certificates, db } from "@server/db";
import { and, eq, isNotNull } from "drizzle-orm"; import { and, eq, isNotNull, or, inArray, sql } from "drizzle-orm";
import { decryptData } from "@server/lib/encryption"; import { decryptData } from "@server/lib/encryption";
import * as fs from "fs"; import * as fs from "fs";
import NodeCache from "node-cache";
const encryptionKeyPath =
config.getRawPrivateConfig().server.encryption_key_path;
if (!fs.existsSync(encryptionKeyPath)) {
throw new Error(
"Encryption key file not found. Please generate one first."
);
}
const encryptionKeyHex = fs.readFileSync(encryptionKeyPath, "utf8").trim();
const encryptionKey = Buffer.from(encryptionKeyHex, "hex");
// Define the return type for clarity and type safety
export type CertificateResult = {
id: number;
domain: string;
queriedDomain: string; // The domain that was originally requested (may differ for wildcards)
wildcard: boolean | null;
certFile: string | null;
keyFile: string | null;
expiresAt: number | null;
updatedAt?: number | null;
};
// --- In-Memory Cache Implementation ---
const certificateCache = new NodeCache({ stdTTL: 180 }); // Cache for 3 minutes (180 seconds)
export async function getValidCertificatesForDomains( export async function getValidCertificatesForDomains(
domains: Set<string> domains: Set<string>,
): Promise< useCache: boolean = true
Array<{ ): Promise<Array<CertificateResult>> {
id: number; const finalResults: CertificateResult[] = [];
domain: string; const domainsToQuery = new Set<string>();
wildcard: boolean | null;
certFile: string | null; // 1. Check cache first if enabled
keyFile: string | null; if (useCache) {
expiresAt: number | null; for (const domain of domains) {
updatedAt?: number | null; const cachedCert = certificateCache.get<CertificateResult>(domain);
}> if (cachedCert) {
> { finalResults.push(cachedCert); // Valid cache hit
if (domains.size === 0) { } else {
return []; domainsToQuery.add(domain); // Cache miss or expired
}
}
} else {
// If caching is disabled, add all domains to the query set
domains.forEach((d) => domainsToQuery.add(d));
} }
const domainArray = Array.from(domains); // 2. If all domains were resolved from the cache, return early
if (domainsToQuery.size === 0) {
return decryptFinalResults(finalResults);
}
// TODO: add more foreign keys to make this query more efficient - we dont need to keep getting every certificate // 3. Prepare domains for the database query
const validCerts = await db const domainsToQueryArray = Array.from(domainsToQuery);
.select({ const parentDomainsToQuery = new Set<string>();
id: certificates.certId,
domain: certificates.domain, domainsToQueryArray.forEach((domain) => {
certFile: certificates.certFile, const parts = domain.split(".");
keyFile: certificates.keyFile, // A wildcard can only match a domain with at least two parts (e.g., example.com)
expiresAt: certificates.expiresAt, if (parts.length > 1) {
updatedAt: certificates.updatedAt, parentDomainsToQuery.add(parts.slice(1).join("."));
wildcard: certificates.wildcard }
}) });
const parentDomainsArray = Array.from(parentDomainsToQuery);
// 4. Build and execute a single, efficient Drizzle query
// This query fetches all potential exact and wildcard matches in one database round-trip.
const potentialCerts = await db
.select()
.from(certificates) .from(certificates)
.where( .where(
and( and(
eq(certificates.status, "valid"), eq(certificates.status, "valid"),
isNotNull(certificates.certFile), isNotNull(certificates.certFile),
isNotNull(certificates.keyFile) isNotNull(certificates.keyFile),
or(
// Condition for exact matches on the requested domains
inArray(certificates.domain, domainsToQueryArray),
// Condition for wildcard matches on the parent domains
parentDomainsArray.length > 0
? and(
inArray(certificates.domain, parentDomainsArray),
eq(certificates.wildcard, true)
)
: // If there are no possible parent domains, this condition is false
sql`false`
)
) )
); );
// Filter certificates for the specified domains and if it is a wildcard then you can match on everything up to the first dot // 5. Process the database results, prioritizing exact matches over wildcards
const validCertsFiltered = validCerts.filter((cert) => { const exactMatches = new Map<string, (typeof potentialCerts)[0]>();
return ( const wildcardMatches = new Map<string, (typeof potentialCerts)[0]>();
domainArray.includes(cert.domain) ||
(cert.wildcard &&
domainArray.some((domain) =>
domain.endsWith(`.${cert.domain}`)
))
);
});
const encryptionKeyPath = config.getRawPrivateConfig().server.encryption_key_path; for (const cert of potentialCerts) {
if (cert.wildcard) {
if (!fs.existsSync(encryptionKeyPath)) { wildcardMatches.set(cert.domain, cert);
throw new Error( } else {
"Encryption key file not found. Please generate one first." exactMatches.set(cert.domain, cert);
); }
} }
const encryptionKeyHex = fs.readFileSync(encryptionKeyPath, "utf8").trim(); for (const domain of domainsToQuery) {
const encryptionKey = Buffer.from(encryptionKeyHex, "hex"); let foundCert: (typeof potentialCerts)[0] | undefined = undefined;
const validCertsDecrypted = validCertsFiltered.map((cert) => { // Priority 1: Check for an exact match
if (exactMatches.has(domain)) {
foundCert = exactMatches.get(domain);
}
// Priority 2: Check for a wildcard match on the parent domain
else {
const parts = domain.split(".");
if (parts.length > 1) {
const parentDomain = parts.slice(1).join(".");
if (wildcardMatches.has(parentDomain)) {
foundCert = wildcardMatches.get(parentDomain);
}
}
}
// If a certificate was found, format it, add to results, and cache it
if (foundCert) {
const resultCert: CertificateResult = {
id: foundCert.certId,
domain: foundCert.domain, // The actual domain of the cert record
queriedDomain: domain, // The domain that was originally requested
wildcard: foundCert.wildcard,
certFile: foundCert.certFile,
keyFile: foundCert.keyFile,
expiresAt: foundCert.expiresAt,
updatedAt: foundCert.updatedAt
};
finalResults.push(resultCert);
// Add to cache for future requests, using the *requested domain* as the key
if (useCache) {
certificateCache.set(domain, resultCert);
}
}
}
return decryptFinalResults(finalResults);
}
function decryptFinalResults(
finalResults: CertificateResult[]
): CertificateResult[] {
const validCertsDecrypted = finalResults.map((cert) => {
// Decrypt and save certificate file // Decrypt and save certificate file
const decryptedCert = decryptData( const decryptedCert = decryptData(
cert.certFile!, // is not null from query cert.certFile!, // is not null from query

View File

@@ -26,6 +26,10 @@ import { orgs, resources, sites, Target, targets } from "@server/db";
import { sanitize, validatePathRewriteConfig } from "@server/lib/traefik/utils"; import { sanitize, validatePathRewriteConfig } from "@server/lib/traefik/utils";
import privateConfig from "#private/lib/config"; import privateConfig from "#private/lib/config";
import createPathRewriteMiddleware from "@server/lib/traefik/middleware"; import createPathRewriteMiddleware from "@server/lib/traefik/middleware";
import {
CertificateResult,
getValidCertificatesForDomains
} from "#private/lib/certificates";
const redirectHttpsMiddlewareName = "redirect-to-https"; const redirectHttpsMiddlewareName = "redirect-to-https";
const redirectToRootMiddlewareName = "redirect-to-root"; const redirectToRootMiddlewareName = "redirect-to-root";
@@ -89,25 +93,11 @@ export async function getTraefikConfig(
subnet: sites.subnet, subnet: sites.subnet,
exitNodeId: sites.exitNodeId, exitNodeId: sites.exitNodeId,
// Namespace // Namespace
domainNamespaceId: domainNamespaces.domainNamespaceId, domainNamespaceId: domainNamespaces.domainNamespaceId
// Certificate fields - we'll get all valid certs and filter in application logic
certificateId: certificates.certId,
certificateDomain: certificates.domain,
certificateWildcard: certificates.wildcard,
certificateStatus: certificates.status
}) })
.from(sites) .from(sites)
.innerJoin(targets, eq(targets.siteId, sites.siteId)) .innerJoin(targets, eq(targets.siteId, sites.siteId))
.innerJoin(resources, eq(resources.resourceId, targets.resourceId)) .innerJoin(resources, eq(resources.resourceId, targets.resourceId))
.leftJoin(
certificates,
and(
eq(certificates.domainId, resources.domainId),
eq(certificates.status, "valid"),
isNotNull(certificates.certFile),
isNotNull(certificates.keyFile)
)
)
.leftJoin( .leftJoin(
targetHealthCheck, targetHealthCheck,
eq(targetHealthCheck.targetId, targets.targetId) eq(targetHealthCheck.targetId, targets.targetId)
@@ -139,14 +129,6 @@ export async function getTraefikConfig(
// Group by resource and include targets with their unique site data // Group by resource and include targets with their unique site data
const resourcesMap = new Map(); const resourcesMap = new Map();
// Track certificates per resource to determine the correct certificate status
const resourceCertificates = new Map<string, Array<{
id: number | null;
domain: string | null;
wildcard: boolean | null;
status: string | null;
}>>();
resourcesWithTargetsAndSites.forEach((row) => { resourcesWithTargetsAndSites.forEach((row) => {
const resourceId = row.resourceId; const resourceId = row.resourceId;
const resourceName = sanitize(row.resourceName) || ""; const resourceName = sanitize(row.resourceName) || "";
@@ -170,25 +152,7 @@ export async function getTraefikConfig(
.filter(Boolean) .filter(Boolean)
.join("-"); .join("-");
const mapKey = [resourceId, pathKey].filter(Boolean).join("-"); const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
const key = sanitize(mapKey) || ""; const key = sanitize(mapKey);
// Track certificates for this resource
if (row.certificateId && row.certificateDomain && row.certificateStatus) {
if (!resourceCertificates.has(key)) {
resourceCertificates.set(key, []);
}
const certList = resourceCertificates.get(key)!;
// Only add if not already present (avoid duplicates from multiple targets)
if (!certList.some(cert => cert.id === row.certificateId)) {
certList.push({
id: row.certificateId,
domain: row.certificateDomain,
wildcard: row.certificateWildcard,
status: row.certificateStatus
});
}
}
if (!resourcesMap.has(key)) { if (!resourcesMap.has(key)) {
const validation = validatePathRewriteConfig( const validation = validatePathRewriteConfig(
@@ -205,26 +169,6 @@ export async function getTraefikConfig(
return; return;
} }
// Determine the correct certificate status for this resource
let certificateStatus: string | null = null;
const resourceCerts = resourceCertificates.get(key) || [];
if (row.fullDomain && resourceCerts.length > 0) {
// Find the best matching certificate
// Priority: exact domain match > wildcard match
const exactMatch = resourceCerts.find(cert =>
cert.domain === row.fullDomain
);
const wildcardMatch = resourceCerts.find(cert =>
cert.wildcard && cert.domain &&
row.fullDomain!.endsWith(`.${cert.domain}`)
);
const matchingCert = exactMatch || wildcardMatch;
certificateStatus = matchingCert?.status || null;
}
resourcesMap.set(key, { resourcesMap.set(key, {
resourceId: row.resourceId, resourceId: row.resourceId,
name: resourceName, name: resourceName,
@@ -240,7 +184,6 @@ export async function getTraefikConfig(
tlsServerName: row.tlsServerName, tlsServerName: row.tlsServerName,
setHostHeader: row.setHostHeader, setHostHeader: row.setHostHeader,
enableProxy: row.enableProxy, enableProxy: row.enableProxy,
certificateStatus: certificateStatus,
targets: [], targets: [],
headers: row.headers, headers: row.headers,
path: row.path, // the targets will all have the same path path: row.path, // the targets will all have the same path
@@ -270,6 +213,19 @@ export async function getTraefikConfig(
}); });
}); });
let validCerts: CertificateResult[] = [];
if (privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) {
// create a list of all domains to get certs for
const domains = new Set<string>();
for (const resource of resourcesMap.values()) {
if (resource.enabled && resource.ssl && resource.fullDomain) {
domains.add(resource.fullDomain);
}
}
// get the valid certs for these domains
validCerts = await getValidCertificatesForDomains(domains, true); // we are caching here because this is called often
}
const config_output: any = { const config_output: any = {
http: { http: {
middlewares: { middlewares: {
@@ -312,14 +268,6 @@ export async function getTraefikConfig(
continue; continue;
} }
// TODO: for now dont filter it out because if you have multiple domain ids and one is failed it causes all of them to fail
if (resource.certificateStatus !== "valid" && privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) {
logger.debug(
`Resource ${resource.resourceId} has certificate status ${resource.certificateStatus}`
);
continue;
}
// add routers and services empty objects if they don't exist // add routers and services empty objects if they don't exist
if (!config_output.http.routers) { if (!config_output.http.routers) {
config_output.http.routers = {}; config_output.http.routers = {};
@@ -329,22 +277,22 @@ export async function getTraefikConfig(
config_output.http.services = {}; config_output.http.services = {};
} }
const domainParts = fullDomain.split(".");
let wildCard;
if (domainParts.length <= 2) {
wildCard = `*.${domainParts.join(".")}`;
} else {
wildCard = `*.${domainParts.slice(1).join(".")}`;
}
if (!resource.subdomain) {
wildCard = resource.fullDomain;
}
const configDomain = config.getDomain(resource.domainId);
let tls = {}; let tls = {};
if (!privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) { if (!privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) {
const domainParts = fullDomain.split(".");
let wildCard;
if (domainParts.length <= 2) {
wildCard = `*.${domainParts.join(".")}`;
} else {
wildCard = `*.${domainParts.slice(1).join(".")}`;
}
if (!resource.subdomain) {
wildCard = resource.fullDomain;
}
const configDomain = config.getDomain(resource.domainId);
let certResolver: string, preferWildcardCert: boolean; let certResolver: string, preferWildcardCert: boolean;
if (!configDomain) { if (!configDomain) {
certResolver = config.getRawConfig().traefik.cert_resolver; certResolver = config.getRawConfig().traefik.cert_resolver;
@@ -367,6 +315,17 @@ export async function getTraefikConfig(
} }
: {}) : {})
}; };
} else {
// find a cert that matches the full domain, if not continue
const matchingCert = validCerts.find(
(cert) => cert.queriedDomain === resource.fullDomain
);
if (!matchingCert) {
logger.warn(
`No matching certificate found for domain: ${resource.fullDomain}`
);
continue;
}
} }
const additionalMiddlewares = const additionalMiddlewares =
@@ -733,20 +692,31 @@ export async function getTraefikConfig(
loginPageId: loginPage.loginPageId, loginPageId: loginPage.loginPageId,
fullDomain: loginPage.fullDomain, fullDomain: loginPage.fullDomain,
exitNodeId: exitNodes.exitNodeId, exitNodeId: exitNodes.exitNodeId,
domainId: loginPage.domainId, domainId: loginPage.domainId
certificateStatus: certificates.status
}) })
.from(loginPage) .from(loginPage)
.innerJoin( .innerJoin(
exitNodes, exitNodes,
eq(exitNodes.exitNodeId, loginPage.exitNodeId) eq(exitNodes.exitNodeId, loginPage.exitNodeId)
) )
.leftJoin(
certificates,
eq(certificates.domainId, loginPage.domainId)
)
.where(eq(exitNodes.exitNodeId, exitNodeId)); .where(eq(exitNodes.exitNodeId, exitNodeId));
let validCertsLoginPages: CertificateResult[] = [];
if (privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) {
// create a list of all domains to get certs for
const domains = new Set<string>();
for (const lp of exitNodeLoginPages) {
if (lp.fullDomain) {
domains.add(lp.fullDomain);
}
}
// get the valid certs for these domains
validCertsLoginPages = await getValidCertificatesForDomains(
domains,
true
); // we are caching here because this is called often
}
if (exitNodeLoginPages.length > 0) { if (exitNodeLoginPages.length > 0) {
if (!config_output.http.services) { if (!config_output.http.services) {
config_output.http.services = {}; config_output.http.services = {};
@@ -776,8 +746,22 @@ export async function getTraefikConfig(
continue; continue;
} }
if (lp.certificateStatus !== "valid") { let tls = {};
continue; if (
!privateConfig.getRawPrivateConfig().flags.use_pangolin_dns
) {
// TODO: we need to add the wildcard logic here too
} else {
// find a cert that matches the full domain, if not continue
const matchingCert = validCertsLoginPages.find(
(cert) => cert.queriedDomain === lp.fullDomain
);
if (!matchingCert) {
logger.warn(
`No matching certificate found for login page domain: ${lp.fullDomain}`
);
continue;
}
} }
// auth-allowed: // auth-allowed:
@@ -800,7 +784,7 @@ export async function getTraefikConfig(
service: "landing-service", service: "landing-service",
rule: `Host(\`${fullDomain}\`) && (PathRegexp(\`^/auth/resource/[^/]+$\`) || PathRegexp(\`^/auth/idp/[0-9]+/oidc/callback\`) || PathPrefix(\`/_next\`) || Path(\`/auth/org\`) || PathRegexp(\`^/__nextjs*\`))`, rule: `Host(\`${fullDomain}\`) && (PathRegexp(\`^/auth/resource/[^/]+$\`) || PathRegexp(\`^/auth/idp/[0-9]+/oidc/callback\`) || PathPrefix(\`/_next\`) || Path(\`/auth/org\`) || PathRegexp(\`^/__nextjs*\`))`,
priority: 203, priority: 203,
tls: {} tls: tls
}; };
// auth-catchall: // auth-catchall:
@@ -819,7 +803,7 @@ export async function getTraefikConfig(
service: "landing-service", service: "landing-service",
rule: `Host(\`${fullDomain}\`)`, rule: `Host(\`${fullDomain}\`)`,
priority: 202, priority: 202,
tls: {} tls: tls
}; };
// we need to add a redirect from http to https too // we need to add a redirect from http to https too

View File

@@ -36,7 +36,7 @@ import { useTranslations } from "next-intl";
import React from "react"; import React from "react";
import { StrategySelect, StrategyOption } from "./StrategySelect"; import { StrategySelect, StrategyOption } from "./StrategySelect";
import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
import { InfoIcon, Check, Loader2 } from "lucide-react"; import { InfoIcon, Check } from "lucide-react";
import { useUserContext } from "@app/hooks/useUserContext"; import { useUserContext } from "@app/hooks/useUserContext";
type FormProps = { type FormProps = {
@@ -67,8 +67,6 @@ export default function GenerateLicenseKeyForm({
firstName: z.string().min(1), firstName: z.string().min(1),
lastName: z.string().min(1), lastName: z.string().min(1),
primaryUse: z.string().min(1), primaryUse: z.string().min(1),
stateProvinceRegion: z.string().min(1),
postalZipCode: z.string().min(1),
country: z.string().min(1), country: z.string().min(1),
phoneNumber: z.string().optional(), phoneNumber: z.string().optional(),
agreedToTerms: z.boolean().refine((val) => val === true), agreedToTerms: z.boolean().refine((val) => val === true),
@@ -110,8 +108,6 @@ export default function GenerateLicenseKeyForm({
firstName: "", firstName: "",
lastName: "", lastName: "",
primaryUse: "", primaryUse: "",
stateProvinceRegion: "",
postalZipCode: "",
country: "", country: "",
phoneNumber: "", phoneNumber: "",
agreedToTerms: false, agreedToTerms: false,
@@ -156,8 +152,6 @@ export default function GenerateLicenseKeyForm({
firstName: "", firstName: "",
lastName: "", lastName: "",
primaryUse: "", primaryUse: "",
stateProvinceRegion: "",
postalZipCode: "",
country: "", country: "",
phoneNumber: "", phoneNumber: "",
agreedToTerms: false, agreedToTerms: false,
@@ -296,8 +290,6 @@ export default function GenerateLicenseKeyForm({
primaryUse: values.primaryUse primaryUse: values.primaryUse
}, },
personalInfo: { personalInfo: {
stateProvinceRegion: values.stateProvinceRegion,
postalZipCode: values.postalZipCode,
country: values.country, country: values.country,
phoneNumber: values.phoneNumber || "" phoneNumber: values.phoneNumber || ""
} }
@@ -516,52 +508,6 @@ export default function GenerateLicenseKeyForm({
/> />
<div className="space-y-4"> <div className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<FormField
control={
personalForm.control
}
name="stateProvinceRegion"
render={({ field }) => (
<FormItem>
<FormLabel>
{t(
"generateLicenseKeyForm.form.stateProvinceRegion"
)}
</FormLabel>
<FormControl>
<Input
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={
personalForm.control
}
name="postalZipCode"
render={({ field }) => (
<FormItem>
<FormLabel>
{t(
"generateLicenseKeyForm.form.postalZipCode"
)}
</FormLabel>
<FormControl>
<Input
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</div>
<div className="grid grid-cols-2 gap-4"> <div className="grid grid-cols-2 gap-4">
<FormField <FormField
control={ control={