diff --git a/messages/en-US.json b/messages/en-US.json index 23a973697..df252ef4a 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -1908,6 +1908,11 @@ "editInternalResourceDialogModePort": "Port", "editInternalResourceDialogModeHost": "Host", "editInternalResourceDialogModeCidr": "CIDR", + "editInternalResourceDialogModeHttp": "HTTP", + "editInternalResourceDialogModeHttps": "HTTPS", + "editInternalResourceDialogScheme": "Scheme", + "editInternalResourceDialogEnableSsl": "Enable SSL", + "editInternalResourceDialogEnableSslDescription": "Enable SSL/TLS encryption for secure HTTPS connections to the destination.", "editInternalResourceDialogDestination": "Destination", "editInternalResourceDialogDestinationHostDescription": "The IP address or hostname of the resource on the site's network.", "editInternalResourceDialogDestinationIPDescription": "The IP or hostname address of the resource on the site's network.", @@ -1923,6 +1928,7 @@ "createInternalResourceDialogName": "Name", "createInternalResourceDialogSite": "Site", "selectSite": "Select site...", + "multiSitesSelectorSitesCount": "{count, plural, one {# site} other {# sites}}", "noSitesFound": "No sites found.", "createInternalResourceDialogProtocol": "Protocol", "createInternalResourceDialogTcp": "TCP", @@ -1951,11 +1957,19 @@ "createInternalResourceDialogModePort": "Port", "createInternalResourceDialogModeHost": "Host", "createInternalResourceDialogModeCidr": "CIDR", + "createInternalResourceDialogModeHttp": "HTTP", + "createInternalResourceDialogModeHttps": "HTTPS", + "scheme": "Scheme", + "createInternalResourceDialogScheme": "Scheme", + "createInternalResourceDialogEnableSsl": "Enable SSL", + "createInternalResourceDialogEnableSslDescription": "Enable SSL/TLS encryption for secure HTTPS connections to the destination.", "createInternalResourceDialogDestination": "Destination", "createInternalResourceDialogDestinationHostDescription": "The IP address or hostname of the resource on the site's 
network.", "createInternalResourceDialogDestinationCidrDescription": "The CIDR range of the resource on the site's network.", "createInternalResourceDialogAlias": "Alias", "createInternalResourceDialogAliasDescription": "An optional internal DNS alias for this resource.", + "internalResourceDownstreamSchemeRequired": "Scheme is required for HTTP resources", + "internalResourceHttpPortRequired": "Destination port is required for HTTP resources", "siteConfiguration": "Configuration", "siteAcceptClientConnections": "Accept Client Connections", "siteAcceptClientConnectionsDescription": "Allow user devices and clients to access resources on this site. This can be changed later.", @@ -2515,6 +2529,7 @@ "validPassword": "Valid Password", "validEmail": "Valid email", "validSSO": "Valid SSO", + "connectedClient": "Connected Client", "resourceBlocked": "Resource Blocked", "droppedByRule": "Dropped by Rule", "noSessions": "No Sessions", @@ -2752,8 +2767,12 @@ "editInternalResourceDialogAddUsers": "Add Users", "editInternalResourceDialogAddClients": "Add Clients", "editInternalResourceDialogDestinationLabel": "Destination", - "editInternalResourceDialogDestinationDescription": "Specify the destination address for the internal resource. This can be a hostname, IP address, or CIDR range depending on the selected mode. Optionally set an internal DNS alias for easier identification.", + "editInternalResourceDialogDestinationDescription": "Choose where this resource runs and how clients reach it. 
Selecting multiple sites will create a high availability resource that can be accessed from any of the selected sites.", "editInternalResourceDialogPortRestrictionsDescription": "Restrict access to specific TCP/UDP ports or allow/block all ports.", + "createInternalResourceDialogHttpConfiguration": "HTTP configuration", + "createInternalResourceDialogHttpConfigurationDescription": "Choose the domain clients will use to reach this resource over HTTP or HTTPS.", + "editInternalResourceDialogHttpConfiguration": "HTTP configuration", + "editInternalResourceDialogHttpConfigurationDescription": "Choose the domain clients will use to reach this resource over HTTP or HTTPS.", "editInternalResourceDialogTcp": "TCP", "editInternalResourceDialogUdp": "UDP", "editInternalResourceDialogIcmp": "ICMP", @@ -2792,6 +2811,8 @@ "maintenancePageMessagePlaceholder": "We'll be back soon! Our site is currently undergoing scheduled maintenance.", "maintenancePageMessageDescription": "Detailed message explaining the maintenance", "maintenancePageTimeTitle": "Estimated Completion Time (Optional)", + "privateMaintenanceScreenTitle": "Private Placeholder Screen", + "privateMaintenanceScreenMessage": "This domain is being used on a private resource. 
Please connect using the Pangolin client to access this resource.", "maintenanceTime": "e.g., 2 hours, Nov 1 at 5:00 PM", "maintenanceEstimatedTimeDescription": "When you expect maintenance to be completed", "editDomain": "Edit Domain", diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index acc3bb17f..25a848f9f 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -57,7 +57,9 @@ export const orgs = pgTable("orgs", { settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year .notNull() .default(0), - settingsLogRetentionDaysConnection: integer("settingsLogRetentionDaysConnection") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year + settingsLogRetentionDaysConnection: integer( + "settingsLogRetentionDaysConnection" + ) // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year .notNull() .default(0), sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format) @@ -101,7 +103,9 @@ export const sites = pgTable("sites", { lastHolePunch: bigint("lastHolePunch", { mode: "number" }), listenPort: integer("listenPort"), dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true), - status: varchar("status").$type<"pending" | "approved">().default("approved") + status: varchar("status") + .$type<"pending" | "approved">() + .default("approved") }); export const resources = pgTable("resources", { @@ -222,16 +226,23 @@ export const exitNodes = pgTable("exitNodes", { export const siteResources = pgTable("siteResources", { // this is for the clients siteResourceId: serial("siteResourceId").primaryKey(), - siteId: integer("siteId") - .notNull() - .references(() => sites.siteId, { onDelete: "cascade" }), orgId: varchar("orgId") .notNull() .references(() => orgs.orgId, { onDelete: "cascade" }), + networkId: 
integer("networkId").references(() => networks.networkId, { + onDelete: "set null" + }), + defaultNetworkId: integer("defaultNetworkId").references( + () => networks.networkId, + { + onDelete: "restrict" + } + ), niceId: varchar("niceId").notNull(), name: varchar("name").notNull(), - mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" - protocol: varchar("protocol"), // only for port mode + ssl: boolean("ssl").notNull().default(false), + mode: varchar("mode").$type<"host" | "cidr" | "http">().notNull(), // "host" | "cidr" | "http" + scheme: varchar("scheme").$type<"http" | "https">(), // only for when we are doing https or http mode proxyPort: integer("proxyPort"), // only for port mode destinationPort: integer("destinationPort"), // only for port mode destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode @@ -244,7 +255,38 @@ export const siteResources = pgTable("siteResources", { authDaemonPort: integer("authDaemonPort").default(22123), authDaemonMode: varchar("authDaemonMode", { length: 32 }) .$type<"site" | "remote">() - .default("site") + .default("site"), + domainId: varchar("domainId").references(() => domains.domainId, { + onDelete: "set null" + }), + subdomain: varchar("subdomain"), + fullDomain: varchar("fullDomain") +}); + +export const networks = pgTable("networks", { + networkId: serial("networkId").primaryKey(), + niceId: text("niceId"), + name: text("name"), + scope: varchar("scope") + .$type<"global" | "resource">() + .notNull() + .default("global"), + orgId: varchar("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade" + }) + .notNull() +}); + +export const siteNetworks = pgTable("siteNetworks", { + siteId: integer("siteId") + .notNull() + .references(() => sites.siteId, { + onDelete: "cascade" + }), + networkId: integer("networkId") + .notNull() + .references(() => networks.networkId, { onDelete: "cascade" }) }); export const clientSiteResources = 
pgTable("clientSiteResources", { @@ -994,6 +1036,7 @@ export const requestAuditLog = pgTable( actor: text("actor"), actorId: text("actorId"), resourceId: integer("resourceId"), + siteResourceId: integer("siteResourceId"), ip: text("ip"), location: text("location"), userAgent: text("userAgent"), @@ -1107,3 +1150,4 @@ export type RequestAuditLog = InferSelectModel; export type RoundTripMessageTracker = InferSelectModel< typeof roundTripMessageTracker >; +export type Network = InferSelectModel; diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 1fb04ef14..5c9d57e6d 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -54,7 +54,9 @@ export const orgs = sqliteTable("orgs", { settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year .notNull() .default(0), - settingsLogRetentionDaysConnection: integer("settingsLogRetentionDaysConnection") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year + settingsLogRetentionDaysConnection: integer( + "settingsLogRetentionDaysConnection" + ) // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year .notNull() .default(0), sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format) @@ -92,6 +94,9 @@ export const sites = sqliteTable("sites", { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { onDelete: "set null" }), + networkId: integer("networkId").references(() => networks.networkId, { + onDelete: "set null" + }), name: text("name").notNull(), pubKey: text("pubKey"), subnet: text("subnet"), @@ -250,16 +255,21 @@ export const siteResources = sqliteTable("siteResources", { siteResourceId: integer("siteResourceId").primaryKey({ autoIncrement: true }), - siteId: integer("siteId") - .notNull() - .references(() => sites.siteId, { onDelete: "cascade" }), 
orgId: text("orgId") .notNull() .references(() => orgs.orgId, { onDelete: "cascade" }), + networkId: integer("networkId").references(() => networks.networkId, { + onDelete: "set null" + }), + defaultNetworkId: integer("defaultNetworkId").references( + () => networks.networkId, + { onDelete: "restrict" } + ), niceId: text("niceId").notNull(), name: text("name").notNull(), - mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" - protocol: text("protocol"), // only for port mode + ssl: integer("ssl", { mode: "boolean" }).notNull().default(false), + mode: text("mode").$type<"host" | "cidr" | "http">().notNull(), // "host" | "cidr" | "http" + scheme: text("scheme").$type<"http" | "https">(), // only for when we are doing https or http mode proxyPort: integer("proxyPort"), // only for port mode destinationPort: integer("destinationPort"), // only for port mode destination: text("destination").notNull(), // ip, cidr, hostname @@ -274,7 +284,36 @@ export const siteResources = sqliteTable("siteResources", { authDaemonPort: integer("authDaemonPort").default(22123), authDaemonMode: text("authDaemonMode") .$type<"site" | "remote">() - .default("site") + .default("site"), + domainId: text("domainId").references(() => domains.domainId, { + onDelete: "set null" + }), + subdomain: text("subdomain"), + fullDomain: text("fullDomain"), +}); + +export const networks = sqliteTable("networks", { + networkId: integer("networkId").primaryKey({ autoIncrement: true }), + niceId: text("niceId"), + name: text("name"), + scope: text("scope") + .$type<"global" | "resource">() + .notNull() + .default("global"), + orgId: text("orgId") + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }) +}); + +export const siteNetworks = sqliteTable("siteNetworks", { + siteId: integer("siteId") + .notNull() + .references(() => sites.siteId, { + onDelete: "cascade" + }), + networkId: integer("networkId") + .notNull() + .references(() => networks.networkId, { 
onDelete: "cascade" }) }); export const clientSiteResources = sqliteTable("clientSiteResources", { @@ -1096,6 +1135,7 @@ export const requestAuditLog = sqliteTable( actor: text("actor"), actorId: text("actorId"), resourceId: integer("resourceId"), + siteResourceId: integer("siteResourceId"), ip: text("ip"), location: text("location"), userAgent: text("userAgent"), @@ -1195,6 +1235,7 @@ export type ApiKey = InferSelectModel; export type ApiKeyAction = InferSelectModel; export type ApiKeyOrg = InferSelectModel; export type SiteResource = InferSelectModel; +export type Network = InferSelectModel; export type OrgDomains = InferSelectModel; export type SetupToken = InferSelectModel; export type HostMeta = InferSelectModel; diff --git a/server/index.ts b/server/index.ts index 0fc44c279..e3a6ba049 100644 --- a/server/index.ts +++ b/server/index.ts @@ -22,6 +22,7 @@ import { TraefikConfigManager } from "@server/lib/traefik/TraefikConfigManager"; import { initCleanup } from "#dynamic/cleanup"; import license from "#dynamic/license/license"; import { initLogCleanupInterval } from "@server/lib/cleanupLogs"; +import { initAcmeCertSync } from "#dynamic/lib/acmeCertSync"; import { fetchServerIp } from "@server/lib/serverIpService"; async function startServers() { @@ -39,6 +40,7 @@ async function startServers() { initTelemetryClient(); initLogCleanupInterval(); + initAcmeCertSync(); // Start all servers const apiServer = createApiServer(); diff --git a/server/lib/acmeCertSync.ts b/server/lib/acmeCertSync.ts new file mode 100644 index 000000000..d8fbd6368 --- /dev/null +++ b/server/lib/acmeCertSync.ts @@ -0,0 +1,3 @@ +export function initAcmeCertSync(): void { + // stub +} \ No newline at end of file diff --git a/server/lib/billing/tierMatrix.ts b/server/lib/billing/tierMatrix.ts index 0756ea665..d64ed1b56 100644 --- a/server/lib/billing/tierMatrix.ts +++ b/server/lib/billing/tierMatrix.ts @@ -20,6 +20,7 @@ export enum TierFeature { FullRbac = "fullRbac", SiteProvisioningKeys = 
"siteProvisioningKeys", // handle downgrade by revoking keys if needed SIEM = "siem", // handle downgrade by disabling SIEM integrations + HTTPPrivateResources = "httpPrivateResources", // handle downgrade by disabling HTTP private resources DomainNamespaces = "domainNamespaces" // handle downgrade by removing custom domain namespaces } @@ -58,5 +59,6 @@ export const tierMatrix: Record = { [TierFeature.FullRbac]: ["tier1", "tier2", "tier3", "enterprise"], [TierFeature.SiteProvisioningKeys]: ["tier3", "enterprise"], [TierFeature.SIEM]: ["enterprise"], + [TierFeature.HTTPPrivateResources]: ["tier3", "enterprise"], [TierFeature.DomainNamespaces]: ["tier1", "tier2", "tier3", "enterprise"] }; diff --git a/server/lib/blueprints/applyBlueprint.ts b/server/lib/blueprints/applyBlueprint.ts index a304bb392..fd189e6ca 100644 --- a/server/lib/blueprints/applyBlueprint.ts +++ b/server/lib/blueprints/applyBlueprint.ts @@ -121,8 +121,8 @@ export async function applyBlueprint({ for (const result of clientResourcesResults) { if ( result.oldSiteResource && - result.oldSiteResource.siteId != - result.newSiteResource.siteId + JSON.stringify(result.newSites?.sort()) !== + JSON.stringify(result.oldSites?.sort()) ) { // query existing associations const existingRoleIds = await trx @@ -222,38 +222,46 @@ export async function applyBlueprint({ trx ); } else { - const [newSite] = await trx - .select() - .from(sites) - .innerJoin(newts, eq(sites.siteId, newts.siteId)) - .where( - and( - eq(sites.siteId, result.newSiteResource.siteId), - eq(sites.orgId, orgId), - eq(sites.type, "newt"), - isNotNull(sites.pubKey) + let good = true; + for (const newSite of result.newSites) { + const [site] = await trx + .select() + .from(sites) + .innerJoin(newts, eq(sites.siteId, newts.siteId)) + .where( + and( + eq(sites.siteId, newSite.siteId), + eq(sites.orgId, orgId), + eq(sites.type, "newt"), + isNotNull(sites.pubKey) + ) ) - ) - .limit(1); + .limit(1); + + if (!site) { + logger.debug( + `No newt sites 
found for client resource ${result.newSiteResource.siteResourceId}, skipping target update` + ); + good = false; + break; + } - if (!newSite) { logger.debug( - `No newt site found for client resource ${result.newSiteResource.siteResourceId}, skipping target update` + `Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.siteId}` ); - continue; } - logger.debug( - `Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.sites.siteId}` - ); + if (!good) { + continue; + } await handleMessagingForUpdatedSiteResource( result.oldSiteResource, result.newSiteResource, - { - siteId: newSite.sites.siteId, - orgId: newSite.sites.orgId - }, + result.newSites.map((site) => ({ + siteId: site.siteId, + orgId: result.newSiteResource.orgId + })), trx ); } diff --git a/server/lib/blueprints/clientResources.ts b/server/lib/blueprints/clientResources.ts index 80c691c63..df1fd0cfb 100644 --- a/server/lib/blueprints/clientResources.ts +++ b/server/lib/blueprints/clientResources.ts @@ -1,24 +1,104 @@ import { clients, clientSiteResources, + domains, + orgDomains, roles, roleSiteResources, + Site, SiteResource, + siteNetworks, siteResources, Transaction, userOrgs, users, - userSiteResources + userSiteResources, + networks } from "@server/db"; import { sites } from "@server/db"; -import { eq, and, ne, inArray, or } from "drizzle-orm"; +import { eq, and, ne, inArray, or, isNotNull } from "drizzle-orm"; import { Config } from "./types"; import logger from "@server/logger"; import { getNextAvailableAliasAddress } from "../ip"; +import { createCertificate } from "#dynamic/routers/certificates/createCertificate"; + +async function getDomainForSiteResource( + siteResourceId: number | undefined, + fullDomain: string, + orgId: string, + trx: Transaction +): Promise<{ subdomain: string | null; domainId: string }> { + const [fullDomainExists] = await trx + .select({ siteResourceId: siteResources.siteResourceId }) + .from(siteResources) + 
.where( + and( + eq(siteResources.fullDomain, fullDomain), + eq(siteResources.orgId, orgId), + siteResourceId + ? ne(siteResources.siteResourceId, siteResourceId) + : isNotNull(siteResources.siteResourceId) + ) + ) + .limit(1); + + if (fullDomainExists) { + throw new Error( + `Site resource already exists with domain: ${fullDomain} in org ${orgId}` + ); + } + + const possibleDomains = await trx + .select() + .from(domains) + .innerJoin(orgDomains, eq(domains.domainId, orgDomains.domainId)) + .where(and(eq(orgDomains.orgId, orgId), eq(domains.verified, true))) + .execute(); + + if (possibleDomains.length === 0) { + throw new Error( + `Domain not found for full-domain: ${fullDomain} in org ${orgId}` + ); + } + + const validDomains = possibleDomains.filter((domain) => { + if (domain.domains.type == "ns" || domain.domains.type == "wildcard") { + return ( + fullDomain === domain.domains.baseDomain || + fullDomain.endsWith(`.${domain.domains.baseDomain}`) + ); + } else if (domain.domains.type == "cname") { + return fullDomain === domain.domains.baseDomain; + } + }); + + if (validDomains.length === 0) { + throw new Error( + `Domain not found for full-domain: ${fullDomain} in org ${orgId}` + ); + } + + const domainSelection = validDomains[0].domains; + const baseDomain = domainSelection.baseDomain; + + let subdomain: string | null = null; + if (fullDomain !== baseDomain) { + subdomain = fullDomain.replace(`.${baseDomain}`, ""); + } + + await createCertificate(domainSelection.domainId, fullDomain, trx); + + return { + subdomain, + domainId: domainSelection.domainId + }; +} export type ClientResourcesResults = { newSiteResource: SiteResource; oldSiteResource?: SiteResource; + newSites: { siteId: number }[]; + oldSites: { siteId: number }[]; }[]; export async function updateClientResources( @@ -43,53 +123,104 @@ export async function updateClientResources( ) .limit(1); - const resourceSiteId = resourceData.site; - let site; + const existingSiteIds = 
existingResource?.networkId + ? await trx + .select({ siteId: sites.siteId }) + .from(siteNetworks) + .where(eq(siteNetworks.networkId, existingResource.networkId)) + : []; - if (resourceSiteId) { - // Look up site by niceId - [site] = await trx - .select({ siteId: sites.siteId }) - .from(sites) - .where( - and( - eq(sites.niceId, resourceSiteId), - eq(sites.orgId, orgId) + let allSites: { siteId: number }[] = []; + if (resourceData.site) { + let siteSingle; + const resourceSiteId = resourceData.site; + + if (resourceSiteId) { + // Look up site by niceId + [siteSingle] = await trx + .select({ siteId: sites.siteId }) + .from(sites) + .where( + and( + eq(sites.niceId, resourceSiteId), + eq(sites.orgId, orgId) + ) ) - ) - .limit(1); - } else if (siteId) { - // Use the provided siteId directly, but verify it belongs to the org - [site] = await trx - .select({ siteId: sites.siteId }) - .from(sites) - .where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) - .limit(1); - } else { - throw new Error(`Target site is required`); + .limit(1); + } else if (siteId) { + // Use the provided siteId directly, but verify it belongs to the org + [siteSingle] = await trx + .select({ siteId: sites.siteId }) + .from(sites) + .where( + and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)) + ) + .limit(1); + } else { + throw new Error(`Target site is required`); + } + + if (!siteSingle) { + throw new Error( + `Site not found: ${resourceSiteId} in org ${orgId}` + ); + } + allSites.push(siteSingle); } - if (!site) { - throw new Error( - `Site not found: ${resourceSiteId} in org ${orgId}` - ); + if (resourceData.sites) { + for (const siteNiceId of resourceData.sites) { + const [site] = await trx + .select({ siteId: sites.siteId }) + .from(sites) + .where( + and( + eq(sites.niceId, siteNiceId), + eq(sites.orgId, orgId) + ) + ) + .limit(1); + if (!site) { + throw new Error( + `Site not found: ${siteNiceId} in org ${orgId}` + ); + } + allSites.push(site); + } } if (existingResource) { + let
domainInfo: + | { subdomain: string | null; domainId: string } + | undefined; + if (resourceData["full-domain"] && resourceData.mode === "http") { + domainInfo = await getDomainForSiteResource( + existingResource.siteResourceId, + resourceData["full-domain"], + orgId, + trx + ); + } + // Update existing resource const [updatedResource] = await trx .update(siteResources) .set({ name: resourceData.name || resourceNiceId, - siteId: site.siteId, mode: resourceData.mode, + ssl: resourceData.ssl, + scheme: resourceData.scheme, destination: resourceData.destination, + destinationPort: resourceData["destination-port"], enabled: true, // hardcoded for now // enabled: resourceData.enabled ?? true, alias: resourceData.alias || null, disableIcmp: resourceData["disable-icmp"], tcpPortRangeString: resourceData["tcp-ports"], - udpPortRangeString: resourceData["udp-ports"] + udpPortRangeString: resourceData["udp-ports"], + fullDomain: resourceData["full-domain"] || null, + subdomain: domainInfo ? domainInfo.subdomain : null, + domainId: domainInfo ? 
domainInfo.domainId : null }) .where( eq( @@ -100,7 +231,21 @@ export async function updateClientResources( .returning(); const siteResourceId = existingResource.siteResourceId; - const orgId = existingResource.orgId; + + if (updatedResource.networkId) { + await trx + .delete(siteNetworks) + .where( + eq(siteNetworks.networkId, updatedResource.networkId) + ); + + for (const site of allSites) { + await trx.insert(siteNetworks).values({ + siteId: site.siteId, + networkId: updatedResource.networkId + }); + } + } await trx .delete(clientSiteResources) @@ -204,37 +349,72 @@ export async function updateClientResources( results.push({ newSiteResource: updatedResource, - oldSiteResource: existingResource + oldSiteResource: existingResource, + newSites: allSites, + oldSites: existingSiteIds }); } else { let aliasAddress: string | null = null; - if (resourceData.mode == "host") { - // we can only have an alias on a host + if (resourceData.mode === "host" || resourceData.mode === "http") { aliasAddress = await getNextAvailableAliasAddress(orgId); } + let domainInfo: + | { subdomain: string | null; domainId: string } + | undefined; + if (resourceData["full-domain"] && resourceData.mode === "http") { + domainInfo = await getDomainForSiteResource( + undefined, + resourceData["full-domain"], + orgId, + trx + ); + } + + const [network] = await trx + .insert(networks) + .values({ + scope: "resource", + orgId: orgId + }) + .returning(); + // Create new resource const [newResource] = await trx .insert(siteResources) .values({ orgId: orgId, - siteId: site.siteId, niceId: resourceNiceId, + networkId: network.networkId, + defaultNetworkId: network.networkId, name: resourceData.name || resourceNiceId, mode: resourceData.mode, + ssl: resourceData.ssl, + scheme: resourceData.scheme, destination: resourceData.destination, + destinationPort: resourceData["destination-port"], enabled: true, // hardcoded for now // enabled: resourceData.enabled ?? 
true, alias: resourceData.alias || null, aliasAddress: aliasAddress, disableIcmp: resourceData["disable-icmp"], tcpPortRangeString: resourceData["tcp-ports"], - udpPortRangeString: resourceData["udp-ports"] + udpPortRangeString: resourceData["udp-ports"], + fullDomain: resourceData["full-domain"] || null, + subdomain: domainInfo ? domainInfo.subdomain : null, + domainId: domainInfo ? domainInfo.domainId : null }) .returning(); const siteResourceId = newResource.siteResourceId; + for (const site of allSites) { + await trx.insert(siteNetworks).values({ + siteId: site.siteId, + networkId: network.networkId + }); + } + const [adminRole] = await trx .select() .from(roles) @@ -324,7 +504,11 @@ export async function updateClientResources( `Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}` ); - results.push({ newSiteResource: newResource }); + results.push({ + newSiteResource: newResource, + newSites: allSites, + oldSites: existingSiteIds + }); } } diff --git a/server/lib/blueprints/proxyResources.ts b/server/lib/blueprints/proxyResources.ts index e16da2ea5..4d78e946d 100644 --- a/server/lib/blueprints/proxyResources.ts +++ b/server/lib/blueprints/proxyResources.ts @@ -1100,7 +1100,7 @@ function checkIfTargetChanged( return false; } -async function getDomain( +export async function getDomain( resourceId: number | undefined, fullDomain: string, orgId: string, diff --git a/server/lib/blueprints/types.ts b/server/lib/blueprints/types.ts index 6ebc509b8..8269e7b65 100644 --- a/server/lib/blueprints/types.ts +++ b/server/lib/blueprints/types.ts @@ -164,6 +164,7 @@ export const ResourceSchema = z name: z.string().optional(), protocol: z.enum(["http", "tcp", "udp"]).optional(), ssl: z.boolean().optional(), + scheme: z.enum(["http", "https"]).optional(), "full-domain": z.string().optional(), "proxy-port": z.int().min(1).max(65535).optional(), enabled: z.boolean().optional(), @@ -325,16 +326,20 @@ export function 
isTargetsOnlyResource(resource: any): boolean { export const ClientResourceSchema = z .object({ name: z.string().min(1).max(255), - mode: z.enum(["host", "cidr"]), - site: z.string(), + mode: z.enum(["host", "cidr", "http"]), + site: z.string(), // DEPRECATED IN FAVOR OF sites + sites: z.array(z.string()).optional().default([]), // protocol: z.enum(["tcp", "udp"]).optional(), // proxyPort: z.int().positive().optional(), - // destinationPort: z.int().positive().optional(), + "destination-port": z.int().positive().optional(), destination: z.string().min(1), // enabled: z.boolean().default(true), "tcp-ports": portRangeStringSchema.optional().default("*"), "udp-ports": portRangeStringSchema.optional().default("*"), "disable-icmp": z.boolean().optional().default(false), + "full-domain": z.string().optional(), + ssl: z.boolean().optional(), + scheme: z.enum(["http", "https"]).optional().nullable(), alias: z .string() .regex( @@ -477,6 +482,39 @@ export const ConfigSchema = z }); } + // Enforce the full-domain uniqueness across client-resources in the same stack + const clientFullDomainMap = new Map(); + + Object.entries(config["client-resources"]).forEach( + ([resourceKey, resource]) => { + const fullDomain = resource["full-domain"]; + if (fullDomain) { + if (!clientFullDomainMap.has(fullDomain)) { + clientFullDomainMap.set(fullDomain, []); + } + clientFullDomainMap.get(fullDomain)!.push(resourceKey); + } + } + ); + + const clientFullDomainDuplicates = Array.from( + clientFullDomainMap.entries() + ) + .filter(([_, resourceKeys]) => resourceKeys.length > 1) + .map( + ([fullDomain, resourceKeys]) => + `'${fullDomain}' used by resources: ${resourceKeys.join(", ")}` + ) + .join("; "); + + if (clientFullDomainDuplicates.length !== 0) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + path: ["client-resources"], + message: `Duplicate 'full-domain' values found: ${clientFullDomainDuplicates}` + }); + } + // Enforce proxy-port uniqueness within proxy-resources per protocol 
const protocolPortMap = new Map(); diff --git a/server/lib/encryption.ts b/server/lib/encryption.ts deleted file mode 100644 index 79caecd1a..000000000 --- a/server/lib/encryption.ts +++ /dev/null @@ -1,39 +0,0 @@ -import crypto from "crypto"; - -export function encryptData(data: string, key: Buffer): string { - const algorithm = "aes-256-gcm"; - const iv = crypto.randomBytes(16); - const cipher = crypto.createCipheriv(algorithm, key, iv); - - let encrypted = cipher.update(data, "utf8", "hex"); - encrypted += cipher.final("hex"); - - const authTag = cipher.getAuthTag(); - - // Combine IV, auth tag, and encrypted data - return iv.toString("hex") + ":" + authTag.toString("hex") + ":" + encrypted; -} - -// Helper function to decrypt data (you'll need this to read certificates) -export function decryptData(encryptedData: string, key: Buffer): string { - const algorithm = "aes-256-gcm"; - const parts = encryptedData.split(":"); - - if (parts.length !== 3) { - throw new Error("Invalid encrypted data format"); - } - - const iv = Buffer.from(parts[0], "hex"); - const authTag = Buffer.from(parts[1], "hex"); - const encrypted = parts[2]; - - const decipher = crypto.createDecipheriv(algorithm, key, iv); - decipher.setAuthTag(authTag); - - let decrypted = decipher.update(encrypted, "hex", "utf8"); - decrypted += decipher.final("utf8"); - - return decrypted; -} - -// openssl rand -hex 32 > config/encryption.key diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 633983629..3e57e8c94 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -5,6 +5,7 @@ import config from "@server/lib/config"; import z from "zod"; import logger from "@server/logger"; import semver from "semver"; +import { getValidCertificatesForDomains } from "#dynamic/lib/certificates"; interface IPRange { start: bigint; @@ -477,9 +478,9 @@ export type Alias = { alias: string | null; aliasAddress: string | null }; export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] { return 
allSiteResources - .filter((sr) => sr.alias && sr.aliasAddress && sr.mode == "host") + .filter((sr) => sr.aliasAddress && ((sr.alias && sr.mode == "host") || (sr.fullDomain && sr.mode == "http"))) .map((sr) => ({ - alias: sr.alias, + alias: sr.alias || sr.fullDomain, aliasAddress: sr.aliasAddress })); } @@ -582,16 +583,26 @@ export type SubnetProxyTargetV2 = { protocol: "tcp" | "udp"; }[]; resourceId?: number; + protocol?: "http" | "https"; // if set, this target only applies to the specified protocol + httpTargets?: HTTPTarget[]; + tlsCert?: string; + tlsKey?: string; }; -export function generateSubnetProxyTargetV2( +export type HTTPTarget = { + destAddr: string; // must be an IP or hostname + destPort: number; + scheme: "http" | "https"; +}; + +export async function generateSubnetProxyTargetV2( siteResource: SiteResource, clients: { clientId: number; pubKey: string | null; subnet: string | null; }[] -): SubnetProxyTargetV2[] | undefined { +): Promise { if (clients.length === 0) { logger.debug( `No clients have access to site resource ${siteResource.siteResourceId}, skipping target generation.` @@ -642,6 +653,67 @@ export function generateSubnetProxyTargetV2( disableIcmp, resourceId: siteResource.siteResourceId }); + } else if (siteResource.mode == "http") { + let destination = siteResource.destination; + // check if this is a valid ip + const ipSchema = z.union([z.ipv4(), z.ipv6()]); + if (ipSchema.safeParse(destination).success) { + destination = `${destination}/32`; + } + + if ( + !siteResource.aliasAddress || + !siteResource.destinationPort || + !siteResource.scheme || + !siteResource.fullDomain + ) { + logger.debug( + `Site resource ${siteResource.siteResourceId} is in HTTP mode but is missing alias or alias address or destinationPort or scheme, skipping alias target generation.` + ); + return; + } + // also push a match for the alias address + let tlsCert: string | undefined; + let tlsKey: string | undefined; + + if (siteResource.ssl && 
siteResource.fullDomain) { + try { + const certs = await getValidCertificatesForDomains( + new Set([siteResource.fullDomain]), + true + ); + if (certs.length > 0 && certs[0].certFile && certs[0].keyFile) { + tlsCert = certs[0].certFile; + tlsKey = certs[0].keyFile; + } else { + logger.warn( + `No valid certificate found for SSL site resource ${siteResource.siteResourceId} with domain ${siteResource.fullDomain}` + ); + } + } catch (err) { + logger.error( + `Failed to retrieve certificate for site resource ${siteResource.siteResourceId} domain ${siteResource.fullDomain}: ${err}` + ); + } + } + + targets.push({ + sourcePrefixes: [], + destPrefix: `${siteResource.aliasAddress}/32`, + rewriteTo: destination, + portRange, + disableIcmp, + resourceId: siteResource.siteResourceId, + protocol: siteResource.ssl ? "https" : "http", + httpTargets: [ + { + destAddr: siteResource.destination, + destPort: siteResource.destinationPort, + scheme: siteResource.scheme + } + ], + ...(tlsCert && tlsKey ? { tlsCert, tlsKey } : {}) + }); } if (targets.length == 0) { diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index d636a2f2b..b213ec9be 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -11,17 +11,16 @@ import { roleSiteResources, Site, SiteResource, + siteNetworks, siteResources, sites, Transaction, userOrgRoles, - userOrgs, userSiteResources } from "@server/db"; import { and, eq, inArray, ne } from "drizzle-orm"; import { - addPeer as newtAddPeer, deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; import { @@ -35,7 +34,6 @@ import { generateRemoteSubnets, generateSubnetProxyTargetV2, parseEndpoint, - formatEndpoint } from "@server/lib/ip"; import { addPeerData, @@ -48,15 +46,27 @@ export async function getClientSiteResourceAccess( siteResource: SiteResource, trx: Transaction | typeof db = db ) { - // get the site - const [site] = await trx - .select() - .from(sites) - 
.where(eq(sites.siteId, siteResource.siteId)) - .limit(1); + // get all sites associated with this siteResource via its network + const sitesList = siteResource.networkId + ? await trx + .select() + .from(sites) + .innerJoin( + siteNetworks, + eq(siteNetworks.siteId, sites.siteId) + ) + .where(eq(siteNetworks.networkId, siteResource.networkId)) + .then((rows) => rows.map((row) => row.sites)) + : []; - if (!site) { - throw new Error(`Site with ID ${siteResource.siteId} not found`); + logger.debug( + `rebuildClientAssociations: [getClientSiteResourceAccess] siteResourceId=${siteResource.siteResourceId} networkId=${siteResource.networkId} siteCount=${sitesList.length} siteIds=[${sitesList.map((s) => s.siteId).join(", ")}]` + ); + + if (sitesList.length === 0) { + logger.warn( + `No sites found for siteResource ${siteResource.siteResourceId} with networkId ${siteResource.networkId}` + ); } const roleIds = await trx @@ -136,8 +146,12 @@ export async function getClientSiteResourceAccess( const mergedAllClients = Array.from(allClientsMap.values()); const mergedAllClientIds = mergedAllClients.map((c) => c.clientId); + logger.debug( + `rebuildClientAssociations: [getClientSiteResourceAccess] siteResourceId=${siteResource.siteResourceId} mergedClientCount=${mergedAllClientIds.length} clientIds=[${mergedAllClientIds.join(", ")}] (userBased=${newAllClients.length} direct=${directClients.length})` + ); + return { - site, + sitesList, mergedAllClients, mergedAllClientIds }; @@ -153,40 +167,59 @@ export async function rebuildClientAssociationsFromSiteResource( subnet: string | null; }[]; }> { - const siteId = siteResource.siteId; + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] START siteResourceId=${siteResource.siteResourceId} networkId=${siteResource.networkId} orgId=${siteResource.orgId}` + ); - const { site, mergedAllClients, mergedAllClientIds } = + const { sitesList, mergedAllClients, mergedAllClientIds } = await 
getClientSiteResourceAccess(siteResource, trx); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] access resolved siteResourceId=${siteResource.siteResourceId} siteCount=${sitesList.length} siteIds=[${sitesList.map((s) => s.siteId).join(", ")}] mergedClientCount=${mergedAllClients.length} clientIds=[${mergedAllClientIds.join(", ")}]` + ); + /////////// process the client-siteResource associations /////////// - // get all of the clients associated with other resources on this site - const allUpdatedClientsFromOtherResourcesOnThisSite = await trx - .select({ - clientId: clientSiteResourcesAssociationsCache.clientId - }) - .from(clientSiteResourcesAssociationsCache) - .innerJoin( - siteResources, - eq( - clientSiteResourcesAssociationsCache.siteResourceId, - siteResources.siteResourceId - ) - ) - .where( - and( - eq(siteResources.siteId, siteId), - ne(siteResources.siteResourceId, siteResource.siteResourceId) - ) - ); + // get all of the clients associated with other resources in the same network, + // joined through siteNetworks so we know which siteId each client belongs to + const allUpdatedClientsFromOtherResourcesOnThisSite = siteResource.networkId + ? 
await trx + .select({ + clientId: clientSiteResourcesAssociationsCache.clientId, + siteId: siteNetworks.siteId + }) + .from(clientSiteResourcesAssociationsCache) + .innerJoin( + siteResources, + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + siteResources.siteResourceId + ) + ) + .innerJoin( + siteNetworks, + eq(siteNetworks.networkId, siteResources.networkId) + ) + .where( + and( + eq(siteResources.networkId, siteResource.networkId), + ne( + siteResources.siteResourceId, + siteResource.siteResourceId + ) + ) + ) + : []; - const allClientIdsFromOtherResourcesOnThisSite = Array.from( - new Set( - allUpdatedClientsFromOtherResourcesOnThisSite.map( - (row) => row.clientId - ) - ) - ); + // Build a per-site map so the loop below can check by siteId rather than + // across the entire network. + const clientsFromOtherResourcesBySite = new Map>(); + for (const row of allUpdatedClientsFromOtherResourcesOnThisSite) { + if (!clientsFromOtherResourcesBySite.has(row.siteId)) { + clientsFromOtherResourcesBySite.set(row.siteId, new Set()); + } + clientsFromOtherResourcesBySite.get(row.siteId)!.add(row.clientId); + } const existingClientSiteResources = await trx .select({ @@ -204,6 +237,10 @@ export async function rebuildClientAssociationsFromSiteResource( (row) => row.clientId ); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} existingResourceClientIds=[${existingClientSiteResourceIds.join(", ")}]` + ); + // Get full client details for existing resource clients (needed for sending delete messages) const existingResourceClients = existingClientSiteResourceIds.length > 0 @@ -223,6 +260,10 @@ export async function rebuildClientAssociationsFromSiteResource( (clientId) => !existingClientSiteResourceIds.includes(clientId) ); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} resourceClients 
toAdd=[${clientSiteResourcesToAdd.join(", ")}]` + ); + const clientSiteResourcesToInsert = clientSiteResourcesToAdd.map( (clientId) => ({ clientId, @@ -231,17 +272,34 @@ export async function rebuildClientAssociationsFromSiteResource( ); if (clientSiteResourcesToInsert.length > 0) { + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} inserting ${clientSiteResourcesToInsert.length} clientSiteResource association(s)` + ); await trx .insert(clientSiteResourcesAssociationsCache) .values(clientSiteResourcesToInsert) .returning(); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} inserted clientSiteResource associations` + ); + } else { + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} no clientSiteResource associations to insert` + ); } const clientSiteResourcesToRemove = existingClientSiteResourceIds.filter( (clientId) => !mergedAllClientIds.includes(clientId) ); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} resourceClients toRemove=[${clientSiteResourcesToRemove.join(", ")}]` + ); + if (clientSiteResourcesToRemove.length > 0) { + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} deleting ${clientSiteResourcesToRemove.length} clientSiteResource association(s)` + ); await trx .delete(clientSiteResourcesAssociationsCache) .where( @@ -260,82 +318,127 @@ export async function rebuildClientAssociationsFromSiteResource( /////////// process the client-site associations /////////// - const existingClientSites = await trx - .select({ - clientId: clientSitesAssociationsCache.clientId - }) - .from(clientSitesAssociationsCache) - 
.where(eq(clientSitesAssociationsCache.siteId, siteResource.siteId)); - - const existingClientSiteIds = existingClientSites.map( - (row) => row.clientId + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteResourceId=${siteResource.siteResourceId} beginning client-site association loop over ${sitesList.length} site(s)` ); - // Get full client details for existing clients (needed for sending delete messages) - const existingClients = await trx - .select({ - clientId: clients.clientId, - pubKey: clients.pubKey, - subnet: clients.subnet - }) - .from(clients) - .where(inArray(clients.clientId, existingClientSiteIds)); + for (const site of sitesList) { + const siteId = site.siteId; - const clientSitesToAdd = mergedAllClientIds.filter( - (clientId) => - !existingClientSiteIds.includes(clientId) && - !allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource - ); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] processing siteId=${siteId} for siteResourceId=${siteResource.siteResourceId}` + ); - const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({ - clientId, - siteId - })); + const existingClientSites = await trx + .select({ + clientId: clientSitesAssociationsCache.clientId + }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.siteId, siteId)); - if (clientSitesToInsert.length > 0) { - await trx - .insert(clientSitesAssociationsCache) - .values(clientSitesToInsert) - .returning(); - } + const existingClientSiteIds = existingClientSites.map( + (row) => row.clientId + ); - // Now remove any client-site associations that should no longer exist - const clientSitesToRemove = existingClientSiteIds.filter( - (clientId) => - !mergedAllClientIds.includes(clientId) && - !allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another 
connection for another site resource - ); + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} existingClientSiteIds=[${existingClientSiteIds.join(", ")}]` + ); - if (clientSitesToRemove.length > 0) { - await trx - .delete(clientSitesAssociationsCache) - .where( - and( - eq(clientSitesAssociationsCache.siteId, siteId), - inArray( - clientSitesAssociationsCache.clientId, - clientSitesToRemove - ) - ) + // Get full client details for existing clients (needed for sending delete messages) + const existingClients = + existingClientSiteIds.length > 0 + ? await trx + .select({ + clientId: clients.clientId, + pubKey: clients.pubKey, + subnet: clients.subnet + }) + .from(clients) + .where(inArray(clients.clientId, existingClientSiteIds)) + : []; + + const otherResourceClientIds = clientsFromOtherResourcesBySite.get(siteId) ?? new Set(); + + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} otherResourceClientIds=[${[...otherResourceClientIds].join(", ")}] mergedAllClientIds=[${mergedAllClientIds.join(", ")}]` + ); + + const clientSitesToAdd = mergedAllClientIds.filter( + (clientId) => + !existingClientSiteIds.includes(clientId) && + !otherResourceClientIds.has(clientId) // dont add if already connected via another site resource + ); + + const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({ + clientId, + siteId + })); + + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} clientSites toAdd=[${clientSitesToAdd.join(", ")}]` + ); + + if (clientSitesToInsert.length > 0) { + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} inserting ${clientSitesToInsert.length} clientSite association(s)` ); + await trx + .insert(clientSitesAssociationsCache) + .values(clientSitesToInsert) + .returning(); + logger.debug( + `rebuildClientAssociations: 
[rebuildClientAssociationsFromSiteResource] siteId=${siteId} inserted clientSite associations` + ); + } else { + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} no clientSite associations to insert` + ); + } + + // Now remove any client-site associations that should no longer exist + const clientSitesToRemove = existingClientSiteIds.filter( + (clientId) => + !mergedAllClientIds.includes(clientId) && + !otherResourceClientIds.has(clientId) // dont remove if there is still another connection for another site resource + ); + + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} clientSites toRemove=[${clientSitesToRemove.join(", ")}]` + ); + + if (clientSitesToRemove.length > 0) { + logger.debug( + `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} deleting ${clientSitesToRemove.length} clientSite association(s)` + ); + await trx + .delete(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.siteId, siteId), + inArray( + clientSitesAssociationsCache.clientId, + clientSitesToRemove + ) + ) + ); + } + + // Now handle the messages to add/remove peers on both the newt and olm sides + await handleMessagesForSiteClients( + site, + siteId, + mergedAllClients, + existingClients, + clientSitesToAdd, + clientSitesToRemove, + trx + ); } - /////////// send the messages /////////// - - // Now handle the messages to add/remove peers on both the newt and olm sides - await handleMessagesForSiteClients( - site, - siteId, - mergedAllClients, - existingClients, - clientSitesToAdd, - clientSitesToRemove, - trx - ); - // Handle subnet proxy target updates for the resource associations await handleSubnetProxyTargetUpdates( siteResource, + sitesList, mergedAllClients, existingResourceClients, clientSiteResourcesToAdd, @@ -624,6 +727,7 @@ export async function updateClientSiteDestinations( async function 
handleSubnetProxyTargetUpdates( siteResource: SiteResource, + sitesList: Site[], allClients: { clientId: number; pubKey: string | null; @@ -638,125 +742,138 @@ async function handleSubnetProxyTargetUpdates( clientSiteResourcesToRemove: number[], trx: Transaction | typeof db = db ): Promise { - // Get the newt for this site - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, siteResource.siteId)) - .limit(1); + const proxyJobs: Promise[] = []; + const olmJobs: Promise[] = []; - if (!newt) { - logger.warn( - `Newt not found for site ${siteResource.siteId}, skipping subnet proxy target updates` - ); - return; - } + for (const siteData of sitesList) { + const siteId = siteData.siteId; - const proxyJobs = []; - const olmJobs = []; - // Generate targets for added associations - if (clientSiteResourcesToAdd.length > 0) { - const addedClients = allClients.filter((client) => - clientSiteResourcesToAdd.includes(client.clientId) - ); + // Get the newt for this site + const [newt] = await trx + .select() + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); - if (addedClients.length > 0) { - const targetsToAdd = generateSubnetProxyTargetV2( - siteResource, - addedClients + if (!newt) { + logger.warn( + `Newt not found for site ${siteId}, skipping subnet proxy target updates` ); - - if (targetsToAdd) { - proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targetsToAdd, - newt.version - ) - ); - } - - for (const client of addedClients) { - olmJobs.push( - addPeerData( - client.clientId, - siteResource.siteId, - generateRemoteSubnets([siteResource]), - generateAliasConfig([siteResource]) - ) - ); - } + continue; } - } - // here we use the existingSiteResource from BEFORE we updated the destination so we dont need to worry about updating destinations here - - // Generate targets for removed associations - if (clientSiteResourcesToRemove.length > 0) { - const removedClients = existingClients.filter((client) => - 
clientSiteResourcesToRemove.includes(client.clientId) - ); - - if (removedClients.length > 0) { - const targetsToRemove = generateSubnetProxyTargetV2( - siteResource, - removedClients + // Generate targets for added associations + if (clientSiteResourcesToAdd.length > 0) { + const addedClients = allClients.filter((client) => + clientSiteResourcesToAdd.includes(client.clientId) ); - if (targetsToRemove) { - proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targetsToRemove, - newt.version - ) + if (addedClients.length > 0) { + const targetsToAdd = await generateSubnetProxyTargetV2( + siteResource, + addedClients ); - } - for (const client of removedClients) { - // Check if this client still has access to another resource on this site with the same destination - const destinationStillInUse = await trx - .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - clientSiteResourcesAssociationsCache.siteResourceId, - siteResources.siteResourceId - ) - ) - .where( - and( - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ), - eq(siteResources.siteId, siteResource.siteId), - eq( - siteResources.destination, - siteResource.destination - ), - ne( - siteResources.siteResourceId, - siteResource.siteResourceId - ) + if (targetsToAdd) { + proxyJobs.push( + addSubnetProxyTargets( + newt.newtId, + targetsToAdd, + newt.version ) ); + } - // Only remove remote subnet if no other resource uses the same destination - const remoteSubnetsToRemove = - destinationStillInUse.length > 0 - ? 
[] - : generateRemoteSubnets([siteResource]); + for (const client of addedClients) { + olmJobs.push( + addPeerData( + client.clientId, + siteId, + generateRemoteSubnets([siteResource]), + generateAliasConfig([siteResource]) + ) + ); + } + } + } - olmJobs.push( - removePeerData( - client.clientId, - siteResource.siteId, - remoteSubnetsToRemove, - generateAliasConfig([siteResource]) - ) + // here we use the existingSiteResource from BEFORE we updated the destination so we dont need to worry about updating destinations here + + // Generate targets for removed associations + if (clientSiteResourcesToRemove.length > 0) { + const removedClients = existingClients.filter((client) => + clientSiteResourcesToRemove.includes(client.clientId) + ); + + if (removedClients.length > 0) { + const targetsToRemove = await generateSubnetProxyTargetV2( + siteResource, + removedClients ); + + if (targetsToRemove) { + proxyJobs.push( + removeSubnetProxyTargets( + newt.newtId, + targetsToRemove, + newt.version + ) + ); + } + + for (const client of removedClients) { + // Check if this client still has access to another resource + // on this specific site with the same destination. We scope + // by siteId (via siteNetworks) rather than networkId because + // removePeerData operates per-site — a resource on a different + // site sharing the same network should not block removal here. 
+ const destinationStillInUse = await trx + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + siteResources.siteResourceId + ) + ) + .innerJoin( + siteNetworks, + eq(siteNetworks.networkId, siteResources.networkId) + ) + .where( + and( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ), + eq(siteNetworks.siteId, siteId), + eq( + siteResources.destination, + siteResource.destination + ), + ne( + siteResources.siteResourceId, + siteResource.siteResourceId + ) + ) + ); + + // Only remove remote subnet if no other resource uses the same destination + const remoteSubnetsToRemove = + destinationStillInUse.length > 0 + ? [] + : generateRemoteSubnets([siteResource]); + + olmJobs.push( + removePeerData( + client.clientId, + siteId, + remoteSubnetsToRemove, + generateAliasConfig([siteResource]) + ) + ); + } } } } @@ -863,10 +980,25 @@ export async function rebuildClientAssociationsFromClient( ) : []; - // Group by siteId for site-level associations - const newSiteIds = Array.from( - new Set(newSiteResources.map((sr) => sr.siteId)) + // Group by siteId for site-level associations — look up via siteNetworks since + // siteResources no longer carries a direct siteId column. + const networkIds = Array.from( + new Set( + newSiteResources + .map((sr) => sr.networkId) + .filter((id): id is number => id !== null) + ) ); + const newSiteIds = + networkIds.length > 0 + ? 
await trx + .select({ siteId: siteNetworks.siteId }) + .from(siteNetworks) + .where(inArray(siteNetworks.networkId, networkIds)) + .then((rows) => + Array.from(new Set(rows.map((r) => r.siteId))) + ) + : []; /////////// Process client-siteResource associations /////////// @@ -1139,13 +1271,45 @@ async function handleMessagesForClientResources( resourcesToAdd.includes(r.siteResourceId) ); + // Build (resource, siteId) pairs by looking up siteNetworks for each resource's networkId + const addedNetworkIds = Array.from( + new Set( + addedResources + .map((r) => r.networkId) + .filter((id): id is number => id !== null) + ) + ); + const addedSiteNetworkRows = + addedNetworkIds.length > 0 + ? await trx + .select({ + networkId: siteNetworks.networkId, + siteId: siteNetworks.siteId + }) + .from(siteNetworks) + .where(inArray(siteNetworks.networkId, addedNetworkIds)) + : []; + const addedNetworkToSites = new Map(); + for (const row of addedSiteNetworkRows) { + if (!addedNetworkToSites.has(row.networkId)) { + addedNetworkToSites.set(row.networkId, []); + } + addedNetworkToSites.get(row.networkId)!.push(row.siteId); + } + // Group by site for proxy updates const addedBySite = new Map(); for (const resource of addedResources) { - if (!addedBySite.has(resource.siteId)) { - addedBySite.set(resource.siteId, []); + const siteIds = + resource.networkId != null + ? (addedNetworkToSites.get(resource.networkId) ?? 
[]) + : []; + for (const siteId of siteIds) { + if (!addedBySite.has(siteId)) { + addedBySite.set(siteId, []); + } + addedBySite.get(siteId)!.push(resource); } - addedBySite.get(resource.siteId)!.push(resource); } // Add subnet proxy targets for each site @@ -1164,7 +1328,7 @@ async function handleMessagesForClientResources( } for (const resource of resources) { - const targets = generateSubnetProxyTargetV2(resource, [ + const targets = await generateSubnetProxyTargetV2(resource, [ { clientId: client.clientId, pubKey: client.pubKey, @@ -1187,7 +1351,7 @@ async function handleMessagesForClientResources( olmJobs.push( addPeerData( client.clientId, - resource.siteId, + siteId, generateRemoteSubnets([resource]), generateAliasConfig([resource]) ) @@ -1199,7 +1363,7 @@ async function handleMessagesForClientResources( error.message.includes("not found") ) { logger.debug( - `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + `Olm data not found for client ${client.clientId} and site ${siteId}, skipping addition` ); } else { throw error; @@ -1216,13 +1380,45 @@ async function handleMessagesForClientResources( .from(siteResources) .where(inArray(siteResources.siteResourceId, resourcesToRemove)); + // Build (resource, siteId) pairs via siteNetworks + const removedNetworkIds = Array.from( + new Set( + removedResources + .map((r) => r.networkId) + .filter((id): id is number => id !== null) + ) + ); + const removedSiteNetworkRows = + removedNetworkIds.length > 0 + ? 
await trx + .select({ + networkId: siteNetworks.networkId, + siteId: siteNetworks.siteId + }) + .from(siteNetworks) + .where(inArray(siteNetworks.networkId, removedNetworkIds)) + : []; + const removedNetworkToSites = new Map(); + for (const row of removedSiteNetworkRows) { + if (!removedNetworkToSites.has(row.networkId)) { + removedNetworkToSites.set(row.networkId, []); + } + removedNetworkToSites.get(row.networkId)!.push(row.siteId); + } + // Group by site for proxy updates const removedBySite = new Map(); for (const resource of removedResources) { - if (!removedBySite.has(resource.siteId)) { - removedBySite.set(resource.siteId, []); + const siteIds = + resource.networkId != null + ? (removedNetworkToSites.get(resource.networkId) ?? []) + : []; + for (const siteId of siteIds) { + if (!removedBySite.has(siteId)) { + removedBySite.set(siteId, []); + } + removedBySite.get(siteId)!.push(resource); } - removedBySite.get(resource.siteId)!.push(resource); } // Remove subnet proxy targets for each site @@ -1241,7 +1437,7 @@ async function handleMessagesForClientResources( } for (const resource of resources) { - const targets = generateSubnetProxyTargetV2(resource, [ + const targets = await generateSubnetProxyTargetV2(resource, [ { clientId: client.clientId, pubKey: client.pubKey, @@ -1260,7 +1456,11 @@ async function handleMessagesForClientResources( } try { - // Check if this client still has access to another resource on this site with the same destination + // Check if this client still has access to another resource + // on this specific site with the same destination. We scope + // by siteId (via siteNetworks) rather than networkId because + // removePeerData operates per-site — a resource on a different + // site sharing the same network should not block removal here. 
const destinationStillInUse = await trx .select() .from(siteResources) @@ -1271,13 +1471,17 @@ async function handleMessagesForClientResources( siteResources.siteResourceId ) ) + .innerJoin( + siteNetworks, + eq(siteNetworks.networkId, siteResources.networkId) + ) .where( and( eq( clientSiteResourcesAssociationsCache.clientId, client.clientId ), - eq(siteResources.siteId, resource.siteId), + eq(siteNetworks.siteId, siteId), eq( siteResources.destination, resource.destination @@ -1299,7 +1503,7 @@ async function handleMessagesForClientResources( olmJobs.push( removePeerData( client.clientId, - resource.siteId, + siteId, remoteSubnetsToRemove, generateAliasConfig([resource]) ) @@ -1311,7 +1515,7 @@ async function handleMessagesForClientResources( error.message.includes("not found") ) { logger.debug( - `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` + `Olm data not found for client ${client.clientId} and site ${siteId}, skipping removal` ); } else { throw error; diff --git a/server/private/lib/acmeCertSync.ts b/server/private/lib/acmeCertSync.ts new file mode 100644 index 000000000..1fb609c28 --- /dev/null +++ b/server/private/lib/acmeCertSync.ts @@ -0,0 +1,478 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. 
+ */ + +import fs from "fs"; +import crypto from "crypto"; +import { + certificates, + clients, + clientSiteResourcesAssociationsCache, + db, + domains, + newts, + siteNetworks, + SiteResource, + siteResources +} from "@server/db"; +import { and, eq } from "drizzle-orm"; +import { encrypt, decrypt } from "@server/lib/crypto"; +import logger from "@server/logger"; +import privateConfig from "#private/lib/config"; +import config from "@server/lib/config"; +import { + generateSubnetProxyTargetV2, + SubnetProxyTargetV2 +} from "@server/lib/ip"; +import { updateTargets } from "@server/routers/client/targets"; +import cache from "#private/lib/cache"; +import { build } from "@server/build"; + +interface AcmeCert { + domain: { main: string; sans?: string[] }; + certificate: string; + key: string; + Store: string; +} + +interface AcmeJson { + [resolver: string]: { + Certificates: AcmeCert[]; + }; +} + +async function pushCertUpdateToAffectedNewts( + domain: string, + domainId: string | null, + oldCertPem: string | null, + oldKeyPem: string | null +): Promise { + // Find all SSL-enabled HTTP site resources that use this cert's domain + let affectedResources: SiteResource[] = []; + + if (domainId) { + affectedResources = await db + .select() + .from(siteResources) + .where( + and( + eq(siteResources.domainId, domainId), + eq(siteResources.ssl, true) + ) + ); + } else { + // Fallback: match by exact fullDomain when no domainId is available + affectedResources = await db + .select() + .from(siteResources) + .where( + and( + eq(siteResources.fullDomain, domain), + eq(siteResources.ssl, true) + ) + ); + } + + if (affectedResources.length === 0) { + logger.debug( + `acmeCertSync: no affected site resources for cert domain "${domain}"` + ); + return; + } + + logger.info( + `acmeCertSync: pushing cert update to ${affectedResources.length} affected site resource(s) for domain "${domain}"` + ); + + for (const resource of affectedResources) { + try { + // Get all sites for this 
resource via siteNetworks + const resourceSiteRows = resource.networkId + ? await db + .select({ siteId: siteNetworks.siteId }) + .from(siteNetworks) + .where(eq(siteNetworks.networkId, resource.networkId)) + : []; + + if (resourceSiteRows.length === 0) { + logger.debug( + `acmeCertSync: no sites for resource ${resource.siteResourceId}, skipping` + ); + continue; + } + + // Get all clients with access to this resource + const resourceClients = await db + .select({ + clientId: clients.clientId, + pubKey: clients.pubKey, + subnet: clients.subnet + }) + .from(clients) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + clients.clientId, + clientSiteResourcesAssociationsCache.clientId + ) + ) + .where( + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + resource.siteResourceId + ) + ); + + if (resourceClients.length === 0) { + logger.debug( + `acmeCertSync: no clients for resource ${resource.siteResourceId}, skipping` + ); + continue; + } + + // Invalidate the cert cache so generateSubnetProxyTargetV2 fetches fresh data + if (resource.fullDomain) { + await cache.del(`cert:${resource.fullDomain}`); + } + + // Generate target once — same cert applies to all sites for this resource + const newTargets = await generateSubnetProxyTargetV2( + resource, + resourceClients + ); + + if (!newTargets) { + logger.debug( + `acmeCertSync: could not generate target for resource ${resource.siteResourceId}, skipping` + ); + continue; + } + + // Construct the old targets — same routing shape but with the previous cert/key. + // The newt only uses destPrefix/sourcePrefixes for removal, but we keep the + // semantics correct so the update message accurately reflects what changed. + const oldTargets: SubnetProxyTargetV2[] = newTargets.map((t) => ({ + ...t, + tlsCert: oldCertPem ?? undefined, + tlsKey: oldKeyPem ?? 
undefined + })); + + // Push update to each site's newt + for (const { siteId } of resourceSiteRows) { + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); + + if (!newt) { + logger.debug( + `acmeCertSync: no newt found for site ${siteId}, skipping resource ${resource.siteResourceId}` + ); + continue; + } + + await updateTargets( + newt.newtId, + { oldTargets: oldTargets, newTargets: newTargets }, + newt.version + ); + + logger.info( + `acmeCertSync: pushed cert update to newt for site ${siteId}, resource ${resource.siteResourceId}` + ); + } + } catch (err) { + logger.error( + `acmeCertSync: error pushing cert update for resource ${resource?.siteResourceId}: ${err}` + ); + } + } +} + +async function findDomainId(certDomain: string): Promise<string | null> { + // Strip wildcard prefix before lookup (*.example.com -> example.com) + const lookupDomain = certDomain.startsWith("*.") + ? certDomain.slice(2) + : certDomain; + + // 1. Exact baseDomain match (any domain type) + const exactMatch = await db + .select({ domainId: domains.domainId }) + .from(domains) + .where(eq(domains.baseDomain, lookupDomain)) + .limit(1); + + if (exactMatch.length > 0) { + return exactMatch[0].domainId; + } + + // 2. Walk up the domain hierarchy looking for a wildcard-type domain whose + // baseDomain is a suffix of the cert domain. e.g. cert "sub.example.com" + // matches a wildcard domain with baseDomain "example.com". 
+ const parts = lookupDomain.split("."); + for (let i = 1; i < parts.length; i++) { + const candidate = parts.slice(i).join("."); + if (!candidate) continue; + + const wildcardMatch = await db + .select({ domainId: domains.domainId }) + .from(domains) + .where( + and( + eq(domains.baseDomain, candidate), + eq(domains.type, "wildcard") + ) + ) + .limit(1); + + if (wildcardMatch.length > 0) { + return wildcardMatch[0].domainId; + } + } + + return null; +} + +function extractFirstCert(pemBundle: string): string | null { + const match = pemBundle.match( + /-----BEGIN CERTIFICATE-----[\s\S]+?-----END CERTIFICATE-----/ + ); + return match ? match[0] : null; +} + +async function syncAcmeCerts( + acmeJsonPath: string, + resolver: string +): Promise { + let raw: string; + try { + raw = fs.readFileSync(acmeJsonPath, "utf8"); + } catch (err) { + logger.debug(`acmeCertSync: could not read ${acmeJsonPath}: ${err}`); + return; + } + + let acmeJson: AcmeJson; + try { + acmeJson = JSON.parse(raw); + } catch (err) { + logger.debug(`acmeCertSync: could not parse acme.json: ${err}`); + return; + } + + const resolverData = acmeJson[resolver]; + if (!resolverData || !Array.isArray(resolverData.Certificates)) { + logger.debug( + `acmeCertSync: no certificates found for resolver "${resolver}"` + ); + return; + } + + for (const cert of resolverData.Certificates) { + const domain = cert.domain?.main; + + if (!domain) { + logger.debug(`acmeCertSync: skipping cert with missing domain`); + continue; + } + + if (!cert.certificate || !cert.key) { + logger.debug( + `acmeCertSync: skipping cert for ${domain} - empty certificate or key field` + ); + continue; + } + + const certPem = Buffer.from(cert.certificate, "base64").toString( + "utf8" + ); + const keyPem = Buffer.from(cert.key, "base64").toString("utf8"); + + if (!certPem.trim() || !keyPem.trim()) { + logger.debug( + `acmeCertSync: skipping cert for ${domain} - blank PEM after base64 decode` + ); + continue; + } + + // Check if cert already 
exists in DB + const existing = await db + .select() + .from(certificates) + .where(eq(certificates.domain, domain)) + .limit(1); + + let oldCertPem: string | null = null; + let oldKeyPem: string | null = null; + + if (existing.length > 0 && existing[0].certFile) { + try { + const storedCertPem = decrypt( + existing[0].certFile, + config.getRawConfig().server.secret! + ); + if (storedCertPem === certPem) { + logger.debug( + `acmeCertSync: cert for ${domain} is unchanged, skipping` + ); + continue; + } + // Cert has changed; capture old values so we can send a correct + // update message to the newt after the DB write. + oldCertPem = storedCertPem; + if (existing[0].keyFile) { + try { + oldKeyPem = decrypt( + existing[0].keyFile, + config.getRawConfig().server.secret! + ); + } catch (keyErr) { + logger.debug( + `acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}` + ); + } + } + } catch (err) { + // Decryption failure means we should proceed with the update + logger.debug( + `acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}` + ); + } + } + + // Parse cert expiry from the first cert in the PEM bundle + let expiresAt: number | null = null; + const firstCertPem = extractFirstCert(certPem); + if (firstCertPem) { + try { + const x509 = new crypto.X509Certificate(firstCertPem); + expiresAt = Math.floor(new Date(x509.validTo).getTime() / 1000); + } catch (err) { + logger.debug( + `acmeCertSync: could not parse cert expiry for ${domain}: ${err}` + ); + } + } + + const wildcard = domain.startsWith("*."); + const encryptedCert = encrypt( + certPem, + config.getRawConfig().server.secret! + ); + const encryptedKey = encrypt( + keyPem, + config.getRawConfig().server.secret! 
+ ); + const now = Math.floor(Date.now() / 1000); + + const domainId = await findDomainId(domain); + if (domainId) { + logger.debug( + `acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"` + ); + } else { + logger.debug( + `acmeCertSync: no matching domain record found for cert domain "${domain}"` + ); + } + + if (existing.length > 0) { + await db + .update(certificates) + .set({ + certFile: encryptedCert, + keyFile: encryptedKey, + status: "valid", + expiresAt, + updatedAt: now, + wildcard, + ...(domainId !== null && { domainId }) + }) + .where(eq(certificates.domain, domain)); + + logger.info( + `acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})` + ); + + await pushCertUpdateToAffectedNewts( + domain, + domainId, + oldCertPem, + oldKeyPem + ); + } else { + await db.insert(certificates).values({ + domain, + domainId, + certFile: encryptedCert, + keyFile: encryptedKey, + status: "valid", + expiresAt, + createdAt: now, + updatedAt: now, + wildcard + }); + + logger.info( + `acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? 
new Date(expiresAt * 1000).toISOString() : "unknown"})` + ); + + // For a brand-new cert, push to any SSL resources that were waiting for it + await pushCertUpdateToAffectedNewts(domain, domainId, null, null); + } + } +} + +export function initAcmeCertSync(): void { + if (build == "saas") { + logger.debug(`acmeCertSync: skipping ACME cert sync in SaaS build`); + return; + } + + const privateConfigData = privateConfig.getRawPrivateConfig(); + + if (!privateConfigData.flags?.enable_acme_cert_sync) { + logger.debug( + `acmeCertSync: ACME cert sync is disabled by config flag, skipping` + ); + return; + } + + if (privateConfigData.flags.use_pangolin_dns) { + logger.debug( + `acmeCertSync: ACME cert sync requires use_pangolin_dns flag to be disabled, skipping` + ); + return; + } + + const acmeJsonPath = + privateConfigData.acme?.acme_json_path ?? + "config/letsencrypt/acme.json"; + const resolver = privateConfigData.acme?.resolver ?? "letsencrypt"; + const intervalMs = privateConfigData.acme?.sync_interval_ms ?? 5000; + + logger.info( + `acmeCertSync: starting ACME cert sync from "${acmeJsonPath}" using resolver "${resolver}" every ${intervalMs}ms` + ); + + // Run immediately on init, then on the configured interval + syncAcmeCerts(acmeJsonPath, resolver).catch((err) => { + logger.error(`acmeCertSync: error during initial sync: ${err}`); + }); + + setInterval(() => { + syncAcmeCerts(acmeJsonPath, resolver).catch((err) => { + logger.error(`acmeCertSync: error during sync: ${err}`); + }); + }, intervalMs); +} diff --git a/server/private/lib/certificates.ts b/server/private/lib/certificates.ts index ae076c48e..af6f6fdaa 100644 --- a/server/private/lib/certificates.ts +++ b/server/private/lib/certificates.ts @@ -11,23 +11,15 @@ * This file is not licensed under the AGPLv3. 
*/ -import config from "./config"; +import privateConfig from "./config"; +import config from "@server/lib/config"; import { certificates, db } from "@server/db"; import { and, eq, isNotNull, or, inArray, sql } from "drizzle-orm"; -import { decryptData } from "@server/lib/encryption"; +import { decrypt } from "@server/lib/crypto"; import logger from "@server/logger"; import cache from "#private/lib/cache"; -let encryptionKeyHex = ""; -let encryptionKey: Buffer; -function loadEncryptData() { - if (encryptionKey) { - return; // already loaded - } - encryptionKeyHex = config.getRawPrivateConfig().server.encryption_key; - encryptionKey = Buffer.from(encryptionKeyHex, "hex"); -} // Define the return type for clarity and type safety export type CertificateResult = { @@ -45,7 +37,7 @@ export async function getValidCertificatesForDomains( domains: Set<string>, useCache: boolean = true ): Promise<CertificateResult[]> { - loadEncryptData(); // Ensure encryption key is loaded + const finalResults: CertificateResult[] = []; const domainsToQuery = new Set<string>(); @@ -68,7 +60,7 @@ // 2. 
If all domains were resolved from the cache, return early if (domainsToQuery.size === 0) { - const decryptedResults = decryptFinalResults(finalResults); + const decryptedResults = decryptFinalResults(finalResults, config.getRawConfig().server.secret!); return decryptedResults; } @@ -173,22 +165,23 @@ export async function getValidCertificatesForDomains( } } - const decryptedResults = decryptFinalResults(finalResults); + const decryptedResults = decryptFinalResults(finalResults, config.getRawConfig().server.secret!); return decryptedResults; } function decryptFinalResults( - finalResults: CertificateResult[] + finalResults: CertificateResult[], + secret: string ): CertificateResult[] { const validCertsDecrypted = finalResults.map((cert) => { // Decrypt and save certificate file - const decryptedCert = decryptData( + const decryptedCert = decrypt( cert.certFile!, // is not null from query - encryptionKey + secret ); // Decrypt and save key file - const decryptedKey = decryptData(cert.keyFile!, encryptionKey); + const decryptedKey = decrypt(cert.keyFile!, secret); // Return only the certificate data without org information return { diff --git a/server/private/lib/readConfigFile.ts b/server/private/lib/readConfigFile.ts index f239edd85..c9cb1535a 100644 --- a/server/private/lib/readConfigFile.ts +++ b/server/private/lib/readConfigFile.ts @@ -34,10 +34,6 @@ export const privateConfigSchema = z.object({ }), server: z .object({ - encryption_key: z - .string() - .optional() - .transform(getEnvOrYaml("SERVER_ENCRYPTION_KEY")), reo_client_id: z .string() .optional() @@ -95,10 +91,21 @@ export const privateConfigSchema = z.object({ .object({ enable_redis: z.boolean().optional().default(false), use_pangolin_dns: z.boolean().optional().default(false), - use_org_only_idp: z.boolean().optional() + use_org_only_idp: z.boolean().optional(), + enable_acme_cert_sync: z.boolean().optional().default(true) }) .optional() .prefault({}), + acme: z + .object({ + acme_json_path: z + 
.string() + .optional() + .default("config/letsencrypt/acme.json"), + resolver: z.string().optional().default("letsencrypt"), + sync_interval_ms: z.number().optional().default(5000) + }) + .optional(), branding: z .object({ app_name: z.string().optional(), diff --git a/server/private/lib/traefik/getTraefikConfig.ts b/server/private/lib/traefik/getTraefikConfig.ts index 5ab96d6d6..404615a86 100644 --- a/server/private/lib/traefik/getTraefikConfig.ts +++ b/server/private/lib/traefik/getTraefikConfig.ts @@ -33,7 +33,7 @@ import { } from "drizzle-orm"; import logger from "@server/logger"; import config from "@server/lib/config"; -import { orgs, resources, sites, Target, targets } from "@server/db"; +import { orgs, resources, sites, siteNetworks, siteResources, Target, targets } from "@server/db"; import { sanitize, encodePath, @@ -267,6 +267,35 @@ export async function getTraefikConfig( }); }); + // Query siteResources in HTTP mode with SSL enabled and aliases — cert generation / HTTPS edge + const siteResourcesWithFullDomain = await db + .select({ + siteResourceId: siteResources.siteResourceId, + fullDomain: siteResources.fullDomain, + mode: siteResources.mode + }) + .from(siteResources) + .innerJoin(siteNetworks, eq(siteResources.networkId, siteNetworks.networkId)) + .innerJoin(sites, eq(siteNetworks.siteId, sites.siteId)) + .where( + and( + eq(siteResources.enabled, true), + isNotNull(siteResources.fullDomain), + eq(siteResources.mode, "http"), + eq(siteResources.ssl, true), + or( + eq(sites.exitNodeId, exitNodeId), + and( + isNull(sites.exitNodeId), + sql`(${siteTypes.includes("local") ? 1 : 0} = 1)`, + eq(sites.type, "local"), + sql`(${build != "saas" ? 
1 : 0} = 1)` + ) + ), + inArray(sites.type, siteTypes) + ) + ); + let validCerts: CertificateResult[] = []; if (privateConfig.getRawPrivateConfig().flags.use_pangolin_dns) { // create a list of all domains to get certs for @@ -276,6 +305,12 @@ export async function getTraefikConfig( domains.add(resource.fullDomain); } } + // Include siteResource aliases so pangolin-dns also fetches certs for them + for (const sr of siteResourcesWithFullDomain) { + if (sr.fullDomain) { + domains.add(sr.fullDomain); + } + } // get the valid certs for these domains validCerts = await getValidCertificatesForDomains(domains, true); // we are caching here because this is called often // logger.debug(`Valid certs for domains: ${JSON.stringify(validCerts)}`); @@ -867,6 +902,139 @@ export async function getTraefikConfig( } } + // Add Traefik routes for siteResource aliases (HTTP mode + SSL) so that + // Traefik generates TLS certificates for those domains even when no + // matching resource exists yet. + if (siteResourcesWithFullDomain.length > 0) { + // Build a set of domains already covered by normal resources + const existingFullDomains = new Set(); + for (const resource of resourcesMap.values()) { + if (resource.fullDomain) { + existingFullDomains.add(resource.fullDomain); + } + } + + for (const sr of siteResourcesWithFullDomain) { + if (!sr.fullDomain) continue; + + // Skip if this alias is already handled by a resource router + if (existingFullDomains.has(sr.fullDomain)) continue; + + const fullDomain = sr.fullDomain; + const srKey = `site-resource-cert-${sr.siteResourceId}`; + const siteResourceServiceName = `${srKey}-service`; + const siteResourceRouterName = `${srKey}-router`; + const siteResourceRewriteMiddlewareName = `${srKey}-rewrite`; + + const maintenancePort = config.getRawConfig().server.next_port; + const maintenanceHost = + config.getRawConfig().server.internal_hostname; + + if (!config_output.http.routers) { + config_output.http.routers = {}; + } + if 
(!config_output.http.services) { + config_output.http.services = {}; + } + if (!config_output.http.middlewares) { + config_output.http.middlewares = {}; + } + + // Service pointing at the internal maintenance/Next.js page + config_output.http.services[siteResourceServiceName] = { + loadBalancer: { + servers: [ + { + url: `http://${maintenanceHost}:${maintenancePort}` + } + ], + passHostHeader: true + } + }; + + // Middleware that rewrites any path to /maintenance-screen + config_output.http.middlewares[ + siteResourceRewriteMiddlewareName + ] = { + replacePathRegex: { + regex: "^/(.*)", + replacement: "/private-maintenance-screen" + } + }; + + // HTTP -> HTTPS redirect so the ACME challenge can be served + config_output.http.routers[ + `${siteResourceRouterName}-redirect` + ] = { + entryPoints: [ + config.getRawConfig().traefik.http_entrypoint + ], + middlewares: [redirectHttpsMiddlewareName], + service: siteResourceServiceName, + rule: `Host(\`${fullDomain}\`)`, + priority: 100 + }; + + // Determine TLS / cert-resolver configuration + let tls: any = {}; + if ( + !privateConfig.getRawPrivateConfig().flags.use_pangolin_dns + ) { + const domainParts = fullDomain.split("."); + const wildCard = + domainParts.length <= 2 + ? `*.${domainParts.join(".")}` + : `*.${domainParts.slice(1).join(".")}`; + + const globalDefaultResolver = + config.getRawConfig().traefik.cert_resolver; + const globalDefaultPreferWildcard = + config.getRawConfig().traefik.prefer_wildcard_cert; + + tls = { + certResolver: globalDefaultResolver, + ...(globalDefaultPreferWildcard + ? 
{ domains: [{ main: wildCard }] } + : {}) + }; + } else { + // pangolin-dns: only add route if we already have a valid cert + const matchingCert = validCerts.find( + (cert) => cert.queriedDomain === fullDomain + ); + if (!matchingCert) { + logger.debug( + `No matching certificate found for siteResource alias: ${fullDomain}` + ); + continue; + } + } + + // HTTPS router — presence of this entry triggers cert generation + config_output.http.routers[siteResourceRouterName] = { + entryPoints: [ + config.getRawConfig().traefik.https_entrypoint + ], + service: siteResourceServiceName, + middlewares: [siteResourceRewriteMiddlewareName], + rule: `Host(\`${fullDomain}\`)`, + priority: 100, + tls + }; + + // Assets bypass router — lets Next.js static files load without rewrite + config_output.http.routers[`${siteResourceRouterName}-assets`] = { + entryPoints: [ + config.getRawConfig().traefik.https_entrypoint + ], + service: siteResourceServiceName, + rule: `Host(\`${fullDomain}\`) && (PathPrefix(\`/_next\`) || PathRegexp(\`^/__nextjs*\`))`, + priority: 101, + tls + }; + } + } + if (generateLoginPageRouters) { const exitNodeLoginPages = await db .select({ diff --git a/server/private/routers/hybrid.ts b/server/private/routers/hybrid.ts index f689df0a5..b3ef792d9 100644 --- a/server/private/routers/hybrid.ts +++ b/server/private/routers/hybrid.ts @@ -24,14 +24,8 @@ import { User, certificates, exitNodeOrgs, - RemoteExitNode, - olms, - newts, - clients, - sites, domains, orgDomains, - targets, loginPage, loginPageOrg, LoginPage, @@ -70,12 +64,9 @@ import { updateAndGenerateEndpointDestinations, updateSiteBandwidth } from "@server/routers/gerbil"; -import * as gerbil from "@server/routers/gerbil"; import logger from "@server/logger"; -import { decryptData } from "@server/lib/encryption"; +import { decrypt } from "@server/lib/crypto"; import config from "@server/lib/config"; -import privateConfig from "#private/lib/config"; -import * as fs from "fs"; import { exchangeSession } 
from "@server/routers/badger"; import { validateResourceSessionToken } from "@server/auth/sessions/resource"; import { checkExitNodeOrg, resolveExitNodes } from "#private/lib/exitNodes"; @@ -298,25 +289,11 @@ hybridRouter.get( } ); -let encryptionKeyHex = ""; -let encryptionKey: Buffer; -function loadEncryptData() { - if (encryptionKey) { - return; // already loaded - } - - encryptionKeyHex = - privateConfig.getRawPrivateConfig().server.encryption_key; - encryptionKey = Buffer.from(encryptionKeyHex, "hex"); -} - // Get valid certificates for given domains (supports wildcard certs) hybridRouter.get( "/certificates/domains", async (req: Request, res: Response, next: NextFunction) => { try { - loadEncryptData(); // Ensure encryption key is loaded - const parsed = getCertificatesByDomainsQuerySchema.safeParse( req.query ); @@ -447,13 +424,13 @@ hybridRouter.get( const result = filtered.map((cert) => { // Decrypt and save certificate file - const decryptedCert = decryptData( + const decryptedCert = decrypt( cert.certFile!, // is not null from query - encryptionKey + config.getRawConfig().server.secret! 
); // Decrypt and save key file - const decryptedKey = decryptData(cert.keyFile!, encryptionKey); + const decryptedKey = decrypt(cert.keyFile!, config.getRawConfig().server.secret!); // Return only the certificate data without org information return { @@ -833,9 +810,12 @@ hybridRouter.get( ) ); - logger.debug(`User ${userId} has roles in org ${orgId}:`, userOrgRoleRows); + logger.debug( + `User ${userId} has roles in org ${orgId}:`, + userOrgRoleRows + ); - return response<{ roleId: number, roleName: string }[]>(res, { + return response<{ roleId: number; roleName: string }[]>(res, { data: userOrgRoleRows, success: true, error: false, diff --git a/server/private/routers/newt/handleConnectionLogMessage.ts b/server/private/routers/newt/handleConnectionLogMessage.ts index fb6ab3453..6355eb783 100644 --- a/server/private/routers/newt/handleConnectionLogMessage.ts +++ b/server/private/routers/newt/handleConnectionLogMessage.ts @@ -92,9 +92,14 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { return; } - // Look up the org for this site + // Look up the org for this site and check retention settings const [site] = await db - .select({ orgId: sites.orgId, orgSubnet: orgs.subnet }) + .select({ + orgId: sites.orgId, + orgSubnet: orgs.subnet, + settingsLogRetentionDaysConnection: + orgs.settingsLogRetentionDaysConnection + }) .from(sites) .innerJoin(orgs, eq(sites.orgId, orgs.orgId)) .where(eq(sites.siteId, newt.siteId)); @@ -108,6 +113,13 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { const orgId = site.orgId; + if (site.settingsLogRetentionDaysConnection === 0) { + logger.debug( + `Connection log retention is disabled for org ${orgId}, skipping` + ); + return; + } + // Extract the CIDR suffix (e.g. "/16") from the org subnet so we can // reconstruct the exact subnet string stored on each client record. 
const cidrSuffix = site.orgSubnet?.includes("/") diff --git a/server/private/routers/newt/handleRequestLogMessage.ts b/server/private/routers/newt/handleRequestLogMessage.ts new file mode 100644 index 000000000..42f1baf2c --- /dev/null +++ b/server/private/routers/newt/handleRequestLogMessage.ts @@ -0,0 +1,238 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { db } from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { sites, Newt, orgs, clients, clientSitesAssociationsCache } from "@server/db"; +import { and, eq, inArray } from "drizzle-orm"; +import logger from "@server/logger"; +import { inflate } from "zlib"; +import { promisify } from "util"; +import { logRequestAudit } from "@server/routers/badger/logRequestAudit"; +import { getCountryCodeForIp } from "@server/lib/geoip"; + +export async function flushRequestLogToDb(): Promise { + return; +} + +const zlibInflate = promisify(inflate); + +interface HTTPRequestLogData { + requestId: string; + resourceId: number; // siteResourceId + timestamp: string; // ISO 8601 + method: string; + scheme: string; // "http" or "https" + host: string; + path: string; + rawQuery?: string; + userAgent?: string; + sourceAddr: string; // ip:port + tls: boolean; +} + +/** + * Decompress a base64-encoded zlib-compressed string into parsed JSON. 
+ */ +async function decompressRequestLog( + compressed: string +): Promise { + const compressedBuffer = Buffer.from(compressed, "base64"); + const decompressed = await zlibInflate(compressedBuffer); + const jsonString = decompressed.toString("utf-8"); + const parsed = JSON.parse(jsonString); + + if (!Array.isArray(parsed)) { + throw new Error("Decompressed request log data is not an array"); + } + + return parsed; +} + +export const handleRequestLogMessage: MessageHandler = async (context) => { + const { message, client } = context; + const newt = client as Newt; + + if (!newt) { + logger.warn("Request log received but no newt client in context"); + return; + } + + if (!newt.siteId) { + logger.warn("Request log received but newt has no siteId"); + return; + } + + if (!message.data?.compressed) { + logger.warn("Request log message missing compressed data"); + return; + } + + // Look up the org for this site and check retention settings + const [site] = await db + .select({ + orgId: sites.orgId, + orgSubnet: orgs.subnet, + settingsLogRetentionDaysRequest: + orgs.settingsLogRetentionDaysRequest + }) + .from(sites) + .innerJoin(orgs, eq(sites.orgId, orgs.orgId)) + .where(eq(sites.siteId, newt.siteId)); + + if (!site) { + logger.warn( + `Request log received but site ${newt.siteId} not found in database` + ); + return; + } + + const orgId = site.orgId; + + if (site.settingsLogRetentionDaysRequest === 0) { + logger.debug( + `Request log retention is disabled for org ${orgId}, skipping` + ); + return; + } + + let entries: HTTPRequestLogData[]; + try { + entries = await decompressRequestLog(message.data.compressed); + } catch (error) { + logger.error("Failed to decompress request log data:", error); + return; + } + + if (entries.length === 0) { + return; + } + + logger.debug(`Request log entries: ${JSON.stringify(entries)}`); + + // Build a map from sourceIp → external endpoint string by joining clients + // with clientSitesAssociationsCache. 
The endpoint is the real-world IP:port + // of the client device and is used for GeoIP lookup. + const ipToEndpoint = new Map(); + + const cidrSuffix = site.orgSubnet?.includes("/") + ? site.orgSubnet.substring(site.orgSubnet.indexOf("/")) + : null; + + if (cidrSuffix) { + const uniqueSourceAddrs = new Set(); + for (const entry of entries) { + if (entry.sourceAddr) { + uniqueSourceAddrs.add(entry.sourceAddr); + } + } + + if (uniqueSourceAddrs.size > 0) { + const subnetQueries = Array.from(uniqueSourceAddrs).map((addr) => { + const ip = addr.includes(":") ? addr.split(":")[0] : addr; + return `${ip}${cidrSuffix}`; + }); + + const matchedClients = await db + .select({ + subnet: clients.subnet, + endpoint: clientSitesAssociationsCache.endpoint + }) + .from(clients) + .innerJoin( + clientSitesAssociationsCache, + and( + eq( + clientSitesAssociationsCache.clientId, + clients.clientId + ), + eq(clientSitesAssociationsCache.siteId, newt.siteId) + ) + ) + .where( + and( + eq(clients.orgId, orgId), + inArray(clients.subnet, subnetQueries) + ) + ); + + for (const c of matchedClients) { + if (c.endpoint) { + const ip = c.subnet.split("/")[0]; + ipToEndpoint.set(ip, c.endpoint); + } + } + } + } + + for (const entry of entries) { + if ( + !entry.requestId || + !entry.resourceId || + !entry.method || + !entry.scheme || + !entry.host || + !entry.path || + !entry.sourceAddr + ) { + logger.debug( + `Skipping request log entry with missing required fields: ${JSON.stringify(entry)}` + ); + continue; + } + + const originalRequestURL = + entry.scheme + + "://" + + entry.host + + entry.path + + (entry.rawQuery ? "?" + entry.rawQuery : ""); + + // Resolve the client's external endpoint for GeoIP lookup. + // sourceAddr is the WireGuard IP (possibly ip:port), so strip the port. + const sourceIp = entry.sourceAddr.includes(":") + ? 
entry.sourceAddr.split(":")[0] + : entry.sourceAddr; + const endpoint = ipToEndpoint.get(sourceIp); + let location: string | undefined; + if (endpoint) { + const endpointIp = endpoint.includes(":") + ? endpoint.split(":")[0] + : endpoint; + location = await getCountryCodeForIp(endpointIp); + } + + await logRequestAudit( + { + action: true, + reason: 108, + siteResourceId: entry.resourceId, + orgId, + location + }, + { + path: entry.path, + originalRequestURL, + scheme: entry.scheme, + host: entry.host, + method: entry.method, + tls: entry.tls, + requestIp: entry.sourceAddr + } + ); + } + + logger.debug( + `Buffered ${entries.length} request log entry/entries from newt ${newt.newtId} (site ${newt.siteId})` + ); +}; diff --git a/server/private/routers/newt/index.ts b/server/private/routers/newt/index.ts index 59d8e980a..94dfc8f05 100644 --- a/server/private/routers/newt/index.ts +++ b/server/private/routers/newt/index.ts @@ -12,3 +12,4 @@ */ export * from "./handleConnectionLogMessage"; +export * from "./handleRequestLogMessage"; diff --git a/server/private/routers/ssh/signSshKey.ts b/server/private/routers/ssh/signSshKey.ts index f929aeca5..82044c0ad 100644 --- a/server/private/routers/ssh/signSshKey.ts +++ b/server/private/routers/ssh/signSshKey.ts @@ -21,7 +21,7 @@ import { roles, roundTripMessageTracker, siteResources, - sites, + siteNetworks, userOrgs } from "@server/db"; import { logAccessAudit } from "#private/lib/logAccessAudit"; @@ -63,10 +63,12 @@ const bodySchema = z export type SignSshKeyResponse = { certificate: string; + messageIds: number[]; messageId: number; sshUsername: string; sshHost: string; resourceId: number; + siteIds: number[]; siteId: number; keyId: string; validPrincipals: string[]; @@ -260,10 +262,7 @@ export async function signSshKey( .update(userOrgs) .set({ pamUsername: usernameToUse }) .where( - and( - eq(userOrgs.orgId, orgId), - eq(userOrgs.userId, userId) - ) + and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)) ); } else { 
usernameToUse = userOrg.pamUsername; @@ -395,21 +394,12 @@ export async function signSshKey( homedir = roleRows[0].sshCreateHomeDir ?? null; } - // get the site - const [newt] = await db - .select() - .from(newts) - .where(eq(newts.siteId, resource.siteId)) - .limit(1); + const sites = await db + .select({ siteId: siteNetworks.siteId }) + .from(siteNetworks) + .where(eq(siteNetworks.networkId, resource.networkId!)); - if (!newt) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "Site associated with resource not found" - ) - ); - } + const siteIds = sites.map((site) => site.siteId); // Sign the public key const now = BigInt(Math.floor(Date.now() / 1000)); @@ -423,43 +413,64 @@ export async function signSshKey( validBefore: now + validFor }); - const [message] = await db - .insert(roundTripMessageTracker) - .values({ - wsClientId: newt.newtId, - messageType: `newt/pam/connection`, - sentAt: Math.floor(Date.now() / 1000) - }) - .returning(); + const messageIds: number[] = []; + for (const siteId of siteIds) { + // get the site + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.siteId, siteId)) + .limit(1); - if (!message) { - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "Failed to create message tracker entry" - ) - ); - } - - await sendToClient(newt.newtId, { - type: `newt/pam/connection`, - data: { - messageId: message.messageId, - orgId: orgId, - agentPort: resource.authDaemonPort ?? 
22123, - externalAuthDaemon: resource.authDaemonMode === "remote", - agentHost: resource.destination, - caCert: caKeys.publicKeyOpenSSH, - username: usernameToUse, - niceId: resource.niceId, - metadata: { - sudoMode: sudoMode, - sudoCommands: parsedSudoCommands, - homedir: homedir, - groups: parsedGroups - } + if (!newt) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Site associated with resource not found" + ) + ); } - }); + + const [message] = await db + .insert(roundTripMessageTracker) + .values({ + wsClientId: newt.newtId, + messageType: `newt/pam/connection`, + sentAt: Math.floor(Date.now() / 1000) + }) + .returning(); + + if (!message) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to create message tracker entry" + ) + ); + } + + messageIds.push(message.messageId); + + await sendToClient(newt.newtId, { + type: `newt/pam/connection`, + data: { + messageId: message.messageId, + orgId: orgId, + agentPort: resource.authDaemonPort ?? 
22123, + externalAuthDaemon: resource.authDaemonMode === "remote", + agentHost: resource.destination, + caCert: caKeys.publicKeyOpenSSH, + username: usernameToUse, + niceId: resource.niceId, + metadata: { + sudoMode: sudoMode, + sudoCommands: parsedSudoCommands, + homedir: homedir, + groups: parsedGroups + } + } + }); + } const expiresIn = Number(validFor); // seconds @@ -480,7 +491,7 @@ export async function signSshKey( metadata: JSON.stringify({ resourceId: resource.siteResourceId, resource: resource.name, - siteId: resource.siteId, + siteIds: siteIds }) }); @@ -494,7 +505,7 @@ export async function signSshKey( : undefined, metadata: { resourceName: resource.name, - siteId: resource.siteId, + siteId: siteIds[0], sshUsername: usernameToUse, sshHost: sshHost }, @@ -505,11 +516,13 @@ export async function signSshKey( return response(res, { data: { certificate: cert.certificate, - messageId: message.messageId, + messageIds: messageIds, + messageId: messageIds[0], // just pick the first one for backward compatibility sshUsername: usernameToUse, sshHost: sshHost, resourceId: resource.siteResourceId, - siteId: resource.siteId, + siteIds: siteIds, + siteId: siteIds[0], // just pick the first one for backward compatibility keyId: cert.keyId, validPrincipals: cert.validPrincipals, validAfter: cert.validAfter.toISOString(), diff --git a/server/private/routers/ws/messageHandlers.ts b/server/private/routers/ws/messageHandlers.ts index 00a9a0ad6..b2553871e 100644 --- a/server/private/routers/ws/messageHandlers.ts +++ b/server/private/routers/ws/messageHandlers.ts @@ -18,12 +18,13 @@ import { } from "#private/routers/remoteExitNode"; import { MessageHandler } from "@server/routers/ws"; import { build } from "@server/build"; -import { handleConnectionLogMessage } from "#private/routers/newt"; +import { handleConnectionLogMessage, handleRequestLogMessage } from "#private/routers/newt"; export const messageHandlers: Record = { "remoteExitNode/register": 
handleRemoteExitNodeRegisterMessage, "remoteExitNode/ping": handleRemoteExitNodePingMessage, "newt/access-log": handleConnectionLogMessage, + "newt/request-log": handleRequestLogMessage, }; if (build != "saas") { diff --git a/server/routers/auditLogs/queryRequestAuditLog.ts b/server/routers/auditLogs/queryRequestAuditLog.ts index 176a9e5d3..000ec9815 100644 --- a/server/routers/auditLogs/queryRequestAuditLog.ts +++ b/server/routers/auditLogs/queryRequestAuditLog.ts @@ -1,8 +1,8 @@ -import { logsDb, primaryLogsDb, requestAuditLog, resources, db, primaryDb } from "@server/db"; +import { logsDb, primaryLogsDb, requestAuditLog, resources, siteResources, db, primaryDb } from "@server/db"; import { registry } from "@server/openApi"; import { NextFunction } from "express"; import { Request, Response } from "express"; -import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm"; +import { eq, gt, lt, and, count, desc, inArray, isNull, or } from "drizzle-orm"; import { OpenAPITags } from "@server/openApi"; import { z } from "zod"; import createHttpError from "http-errors"; @@ -92,7 +92,10 @@ function getWhere(data: Q) { lt(requestAuditLog.timestamp, data.timeEnd), eq(requestAuditLog.orgId, data.orgId), data.resourceId - ? eq(requestAuditLog.resourceId, data.resourceId) + ? or( + eq(requestAuditLog.resourceId, data.resourceId), + eq(requestAuditLog.siteResourceId, data.resourceId) + ) : undefined, data.actor ? eq(requestAuditLog.actor, data.actor) : undefined, data.method ? 
eq(requestAuditLog.method, data.method) : undefined, @@ -110,15 +113,16 @@ export function queryRequest(data: Q) { return primaryLogsDb .select({ id: requestAuditLog.id, - timestamp: requestAuditLog.timestamp, - orgId: requestAuditLog.orgId, - action: requestAuditLog.action, - reason: requestAuditLog.reason, - actorType: requestAuditLog.actorType, - actor: requestAuditLog.actor, - actorId: requestAuditLog.actorId, - resourceId: requestAuditLog.resourceId, - ip: requestAuditLog.ip, + timestamp: requestAuditLog.timestamp, + orgId: requestAuditLog.orgId, + action: requestAuditLog.action, + reason: requestAuditLog.reason, + actorType: requestAuditLog.actorType, + actor: requestAuditLog.actor, + actorId: requestAuditLog.actorId, + resourceId: requestAuditLog.resourceId, + siteResourceId: requestAuditLog.siteResourceId, + ip: requestAuditLog.ip, location: requestAuditLog.location, userAgent: requestAuditLog.userAgent, metadata: requestAuditLog.metadata, @@ -137,37 +141,73 @@ export function queryRequest(data: Q) { } async function enrichWithResourceDetails(logs: Awaited>) { - // If logs database is the same as main database, we can do a join - // Otherwise, we need to fetch resource details separately const resourceIds = logs .map(log => log.resourceId) .filter((id): id is number => id !== null && id !== undefined); - if (resourceIds.length === 0) { + const siteResourceIds = logs + .filter(log => log.resourceId == null && log.siteResourceId != null) + .map(log => log.siteResourceId) + .filter((id): id is number => id !== null && id !== undefined); + + if (resourceIds.length === 0 && siteResourceIds.length === 0) { return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null })); } - // Fetch resource details from main database - const resourceDetails = await primaryDb - .select({ - resourceId: resources.resourceId, - name: resources.name, - niceId: resources.niceId - }) - .from(resources) - .where(inArray(resources.resourceId, resourceIds)); + const 
resourceMap = new Map(); - // Create a map for quick lookup - const resourceMap = new Map( - resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }]) - ); + if (resourceIds.length > 0) { + const resourceDetails = await primaryDb + .select({ + resourceId: resources.resourceId, + name: resources.name, + niceId: resources.niceId + }) + .from(resources) + .where(inArray(resources.resourceId, resourceIds)); + + for (const r of resourceDetails) { + resourceMap.set(r.resourceId, { name: r.name, niceId: r.niceId }); + } + } + + const siteResourceMap = new Map(); + + if (siteResourceIds.length > 0) { + const siteResourceDetails = await primaryDb + .select({ + siteResourceId: siteResources.siteResourceId, + name: siteResources.name, + niceId: siteResources.niceId + }) + .from(siteResources) + .where(inArray(siteResources.siteResourceId, siteResourceIds)); + + for (const r of siteResourceDetails) { + siteResourceMap.set(r.siteResourceId, { name: r.name, niceId: r.niceId }); + } + } // Enrich logs with resource details - return logs.map(log => ({ - ...log, - resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null, - resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null - })); + return logs.map(log => { + if (log.resourceId != null) { + const details = resourceMap.get(log.resourceId); + return { + ...log, + resourceName: details?.name ?? null, + resourceNiceId: details?.niceId ?? null + }; + } else if (log.siteResourceId != null) { + const details = siteResourceMap.get(log.siteResourceId); + return { + ...log, + resourceId: log.siteResourceId, + resourceName: details?.name ?? null, + resourceNiceId: details?.niceId ?? 
null + }; + } + return { ...log, resourceName: null, resourceNiceId: null }; + }); } export function countRequestQuery(data: Q) { @@ -211,7 +251,8 @@ async function queryUniqueFilterAttributes( uniqueLocations, uniqueHosts, uniquePaths, - uniqueResources + uniqueResources, + uniqueSiteResources ] = await Promise.all([ primaryLogsDb .selectDistinct({ actor: requestAuditLog.actor }) @@ -239,6 +280,13 @@ async function queryUniqueFilterAttributes( }) .from(requestAuditLog) .where(baseConditions) + .limit(DISTINCT_LIMIT + 1), + primaryLogsDb + .selectDistinct({ + id: requestAuditLog.siteResourceId + }) + .from(requestAuditLog) + .where(and(baseConditions, isNull(requestAuditLog.resourceId))) .limit(DISTINCT_LIMIT + 1) ]); @@ -259,6 +307,10 @@ async function queryUniqueFilterAttributes( .map(row => row.id) .filter((id): id is number => id !== null); + const siteResourceIds = uniqueSiteResources + .map(row => row.id) + .filter((id): id is number => id !== null); + let resourcesWithNames: Array<{ id: number; name: string | null }> = []; if (resourceIds.length > 0) { @@ -270,10 +322,31 @@ async function queryUniqueFilterAttributes( .from(resources) .where(inArray(resources.resourceId, resourceIds)); - resourcesWithNames = resourceDetails.map(r => ({ - id: r.resourceId, - name: r.name - })); + resourcesWithNames = [ + ...resourcesWithNames, + ...resourceDetails.map(r => ({ + id: r.resourceId, + name: r.name + })) + ]; + } + + if (siteResourceIds.length > 0) { + const siteResourceDetails = await primaryDb + .select({ + siteResourceId: siteResources.siteResourceId, + name: siteResources.name + }) + .from(siteResources) + .where(inArray(siteResources.siteResourceId, siteResourceIds)); + + resourcesWithNames = [ + ...resourcesWithNames, + ...siteResourceDetails.map(r => ({ + id: r.siteResourceId, + name: r.name + })) + ]; } return { diff --git a/server/routers/auditLogs/types.ts b/server/routers/auditLogs/types.ts index 4c278cba5..972eebfe3 100644 --- 
a/server/routers/auditLogs/types.ts +++ b/server/routers/auditLogs/types.ts @@ -28,6 +28,7 @@ export type QueryRequestAuditLogResponse = { actor: string | null; actorId: string | null; resourceId: number | null; + siteResourceId: number | null; resourceNiceId: string | null; resourceName: string | null; ip: string | null; diff --git a/server/routers/badger/logRequestAudit.ts b/server/routers/badger/logRequestAudit.ts index 92d01332e..884fb7ae4 100644 --- a/server/routers/badger/logRequestAudit.ts +++ b/server/routers/badger/logRequestAudit.ts @@ -18,6 +18,7 @@ Reasons: 105 - Valid Password 106 - Valid email 107 - Valid SSO +108 - Connected Client 201 - Resource Not Found 202 - Resource Blocked @@ -38,6 +39,7 @@ const auditLogBuffer: Array<{ metadata: any; action: boolean; resourceId?: number; + siteResourceId?: number; reason: number; location?: string; originalRequestURL: string; @@ -186,6 +188,7 @@ export async function logRequestAudit( action: boolean; reason: number; resourceId?: number; + siteResourceId?: number; orgId?: string; location?: string; user?: { username: string; userId: string }; @@ -262,6 +265,7 @@ export async function logRequestAudit( metadata: sanitizeString(metadata), action: data.action, resourceId: data.resourceId, + siteResourceId: data.siteResourceId, reason: data.reason, location: sanitizeString(data.location), originalRequestURL: sanitizeString(body.originalRequestURL) ?? 
"", diff --git a/server/routers/newt/buildConfiguration.ts b/server/routers/newt/buildConfiguration.ts index afb196152..6f0bad6f1 100644 --- a/server/routers/newt/buildConfiguration.ts +++ b/server/routers/newt/buildConfiguration.ts @@ -4,8 +4,10 @@ import { clientSitesAssociationsCache, db, ExitNode, + networks, resources, Site, + siteNetworks, siteResources, targetHealthCheck, targets @@ -137,11 +139,14 @@ export async function buildClientConfigurationForNewtClient( // Filter out any null values from peers that didn't have an olm const validPeers = peers.filter((peer) => peer !== null); - // Get all enabled site resources for this site + // Get all enabled site resources for this site by joining through siteNetworks and networks const allSiteResources = await db .select() .from(siteResources) - .where(eq(siteResources.siteId, siteId)); + .innerJoin(networks, eq(siteResources.networkId, networks.networkId)) + .innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId)) + .where(eq(siteNetworks.siteId, siteId)) + .then((rows) => rows.map((r) => r.siteResources)); const targetsToSend: SubnetProxyTargetV2[] = []; @@ -168,7 +173,7 @@ export async function buildClientConfigurationForNewtClient( ) ); - const resourceTargets = generateSubnetProxyTargetV2( + const resourceTargets = await generateSubnetProxyTargetV2( resource, resourceClients ); diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleNewtGetConfigMessage.ts similarity index 92% rename from server/routers/newt/handleGetConfigMessage.ts rename to server/routers/newt/handleNewtGetConfigMessage.ts index 9c67f53ee..787151a5a 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleNewtGetConfigMessage.ts @@ -10,7 +10,7 @@ import { convertTargetsIfNessicary } from "../client/targets"; import { canCompress } from "@server/lib/clientVersionChecks"; import config from "@server/lib/config"; -export const handleGetConfigMessage: 
MessageHandler = async (context) => { +export const handleNewtGetConfigMessage: MessageHandler = async (context) => { const { message, client, sendToClient } = context; const newt = client as Newt; @@ -56,7 +56,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) { logger.warn( - `Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?` + `Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the site reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?` ); return; } @@ -113,7 +113,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { exitNode ); - const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets); + const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets); // for backward compatibility with old newt versions that don't support the new target format return { message: { diff --git a/server/routers/newt/handleRequestLogMessage.ts b/server/routers/newt/handleRequestLogMessage.ts new file mode 100644 index 000000000..190020ad1 --- /dev/null +++ b/server/routers/newt/handleRequestLogMessage.ts @@ -0,0 +1,9 @@ +import { MessageHandler } from "@server/routers/ws"; + +export async function flushRequestLogToDb(): Promise { + return; +} + +export const handleRequestLogMessage: MessageHandler = async (context) => { + return; +}; \ No newline at end of file diff --git a/server/routers/newt/index.ts b/server/routers/newt/index.ts index 33b5caf7c..fe6998722 100644 --- a/server/routers/newt/index.ts +++ b/server/routers/newt/index.ts @@ -2,11 +2,12 @@ export * from "./createNewt"; export * from "./getNewtToken"; export 
* from "./handleNewtRegisterMessage"; export * from "./handleReceiveBandwidthMessage"; -export * from "./handleGetConfigMessage"; +export * from "./handleNewtGetConfigMessage"; export * from "./handleSocketMessages"; export * from "./handleNewtPingRequestMessage"; export * from "./handleApplyBlueprintMessage"; export * from "./handleNewtPingMessage"; export * from "./handleNewtDisconnectingMessage"; export * from "./handleConnectionLogMessage"; +export * from "./handleRequestLogMessage"; export * from "./registerNewt"; diff --git a/server/routers/olm/buildConfiguration.ts b/server/routers/olm/buildConfiguration.ts index bc2611b1c..4182725d3 100644 --- a/server/routers/olm/buildConfiguration.ts +++ b/server/routers/olm/buildConfiguration.ts @@ -4,6 +4,8 @@ import { clientSitesAssociationsCache, db, exitNodes, + networks, + siteNetworks, siteResources, sites } from "@server/db"; @@ -59,9 +61,17 @@ export async function buildSiteConfigurationForOlmClient( clientSiteResourcesAssociationsCache.siteResourceId ) ) + .innerJoin( + networks, + eq(siteResources.networkId, networks.networkId) + ) + .innerJoin( + siteNetworks, + eq(networks.networkId, siteNetworks.networkId) + ) .where( and( - eq(siteResources.siteId, site.siteId), + eq(siteNetworks.siteId, site.siteId), eq( clientSiteResourcesAssociationsCache.clientId, client.clientId @@ -69,6 +79,7 @@ export async function buildSiteConfigurationForOlmClient( ) ); + if (jitMode) { // Add site configuration to the array siteConfigurations.push({ diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 01495de3b..a4a62973d 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -17,7 +17,6 @@ import { getUserDeviceName } from "@server/db/names"; import { buildSiteConfigurationForOlmClient } from "./buildConfiguration"; import { OlmErrorCodes, sendOlmError } from "./error"; import { handleFingerprintInsertion } 
from "./fingerprintingUtils"; -import { Alias } from "@server/lib/ip"; import { build } from "@server/build"; import { canCompress } from "@server/lib/clientVersionChecks"; import config from "@server/lib/config"; diff --git a/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts b/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts index 54badb2dc..0eda41e04 100644 --- a/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts +++ b/server/routers/olm/handleOlmServerInitAddPeerHandshake.ts @@ -4,10 +4,12 @@ import { db, exitNodes, Site, - siteResources + siteNetworks, + siteResources, + sites } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { clients, Olm, sites } from "@server/db"; +import { clients, Olm } from "@server/db"; import { and, eq, or } from "drizzle-orm"; import logger from "@server/logger"; import { initPeerAddHandshake } from "./peers"; @@ -44,20 +46,31 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async ( const { siteId, resourceId, chainId } = message.data; - let site: Site | null = null; + const sendCancel = async () => { + await sendToClient( + olm.olmId, + { + type: "olm/wg/peer/chain/cancel", + data: { chainId } + }, + { incrementConfigVersion: false } + ).catch((error) => { + logger.warn(`Error sending message:`, error); + }); + }; + + let sitesToProcess: Site[] = []; + if (siteId) { - // get the site const [siteRes] = await db .select() .from(sites) .where(eq(sites.siteId, siteId)) .limit(1); if (siteRes) { - site = siteRes; + sitesToProcess = [siteRes]; } - } - - if (resourceId && !site) { + } else if (resourceId) { const resources = await db .select() .from(siteResources) @@ -72,27 +85,17 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async ( ); if (!resources || resources.length === 0) { - logger.error(`handleOlmServerPeerAddMessage: Resource not found`); - // cancel the request from the olm side to not keep doing this - await sendToClient( - 
olm.olmId, - { - type: "olm/wg/peer/chain/cancel", - data: { - chainId - } - }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); + logger.error( + `handleOlmServerInitAddPeerHandshake: Resource not found` + ); + await sendCancel(); return; } if (resources.length > 1) { // error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches logger.error( - `handleOlmServerPeerAddMessage: Multiple resources found matching the criteria` + `handleOlmServerInitAddPeerHandshake: Multiple resources found matching the criteria` ); return; } @@ -117,125 +120,120 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async ( if (currentResourceAssociationCaches.length === 0) { logger.error( - `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}` + `handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}` ); - // cancel the request from the olm side to not keep doing this - await sendToClient( - olm.olmId, - { - type: "olm/wg/peer/chain/cancel", - data: { - chainId - } - }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); + await sendCancel(); return; } - const siteIdFromResource = resource.siteId; - - // get the site - const [siteRes] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteIdFromResource)); - if (!siteRes) { + if (!resource.networkId) { logger.error( - `handleOlmServerPeerAddMessage: Site with ID ${site} not found` + `handleOlmServerInitAddPeerHandshake: Resource ${resource.siteResourceId} has no network` ); + await sendCancel(); return; } - site = siteRes; + // Get all sites associated with this resource's network via siteNetworks + const siteRows = await db 
+ .select({ siteId: siteNetworks.siteId }) + .from(siteNetworks) + .where(eq(siteNetworks.networkId, resource.networkId)); + + if (!siteRows || siteRows.length === 0) { + logger.error( + `handleOlmServerInitAddPeerHandshake: No sites found for resource ${resource.siteResourceId}` + ); + await sendCancel(); + return; + } + + // Fetch full site objects for all network members + const foundSites = await Promise.all( + siteRows.map(async ({ siteId: sid }) => { + const [s] = await db + .select() + .from(sites) + .where(eq(sites.siteId, sid)) + .limit(1); + return s ?? null; + }) + ); + + sitesToProcess = foundSites.filter((s): s is Site => s !== null); } - if (!site) { - logger.error(`handleOlmServerPeerAddMessage: Site not found`); + if (sitesToProcess.length === 0) { + logger.error( + `handleOlmServerInitAddPeerHandshake: No sites to process` + ); + await sendCancel(); return; } - // check if the client can access this site using the cache - const currentSiteAssociationCaches = await db - .select() - .from(clientSitesAssociationsCache) - .where( - and( - eq(clientSitesAssociationsCache.clientId, client.clientId), - eq(clientSitesAssociationsCache.siteId, site.siteId) - ) - ); + let handshakeInitiated = false; - if (currentSiteAssociationCaches.length === 0) { - logger.error( - `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}` - ); - // cancel the request from the olm side to not keep doing this - await sendToClient( - olm.olmId, + for (const site of sitesToProcess) { + // Check if the client can access this site using the cache + const currentSiteAssociationCaches = await db + .select() + .from(clientSitesAssociationsCache) + .where( + and( + eq(clientSitesAssociationsCache.clientId, client.clientId), + eq(clientSitesAssociationsCache.siteId, site.siteId) + ) + ); + + if (currentSiteAssociationCaches.length === 0) { + logger.warn( + `handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access 
to site ${site.siteId}, skipping` + ); + continue; + } + + if (!site.exitNodeId) { + logger.error( + `handleOlmServerInitAddPeerHandshake: Site ${site.siteId} has no exit node, skipping` + ); + continue; + } + + const [exitNode] = await db + .select() + .from(exitNodes) + .where(eq(exitNodes.exitNodeId, site.exitNodeId)); + + if (!exitNode) { + logger.error( + `handleOlmServerInitAddPeerHandshake: Exit node not found for site ${site.siteId}, skipping` + ); + continue; + } + + // Trigger the peer add handshake — if the peer was already added this will be a no-op + await initPeerAddHandshake( + client.clientId, { - type: "olm/wg/peer/chain/cancel", - data: { - chainId + siteId: site.siteId, + exitNode: { + publicKey: exitNode.publicKey, + endpoint: exitNode.endpoint } }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); - return; - } - - if (!site.exitNodeId) { - logger.error( - `handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node` - ); - // cancel the request from the olm side to not keep doing this - await sendToClient( olm.olmId, - { - type: "olm/wg/peer/chain/cancel", - data: { - chainId - } - }, - { incrementConfigVersion: false } - ).catch((error) => { - logger.warn(`Error sending message:`, error); - }); - return; - } - - // get the exit node from the side - const [exitNode] = await db - .select() - .from(exitNodes) - .where(eq(exitNodes.exitNodeId, site.exitNodeId)); - - if (!exitNode) { - logger.error( - `handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node` + chainId ); - return; + + handshakeInitiated = true; } - // also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch - // if it has already been added this will be a no-op - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { - siteId: site.siteId, - exitNode: { - publicKey: 
exitNode.publicKey, - endpoint: exitNode.endpoint - } - }, - olm.olmId, - chainId - ); + if (!handshakeInitiated) { + logger.error( + `handleOlmServerInitAddPeerHandshake: No accessible sites with valid exit nodes found, cancelling chain` + ); + await sendCancel(); + } return; -}; +}; \ No newline at end of file diff --git a/server/routers/olm/handleOlmServerPeerAddMessage.ts b/server/routers/olm/handleOlmServerPeerAddMessage.ts index 64284f493..5f46ea84c 100644 --- a/server/routers/olm/handleOlmServerPeerAddMessage.ts +++ b/server/routers/olm/handleOlmServerPeerAddMessage.ts @@ -1,43 +1,25 @@ import { - Client, clientSiteResourcesAssociationsCache, db, - ExitNode, - Org, - orgs, - roleClients, - roles, + networks, + siteNetworks, siteResources, - Transaction, - userClients, - userOrgs, - users } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; import { clients, clientSitesAssociationsCache, - exitNodes, Olm, - olms, sites } from "@server/db"; import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; -import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; -import { listExitNodes } from "#dynamic/lib/exitNodes"; import { generateAliasConfig, - getNextAvailableClientSubnet } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip"; -import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations"; -import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; -import { validateSessionToken } from "@server/auth/sessions/app"; -import config from "@server/lib/config"; import { addPeer as newtAddPeer, - deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; export const handleOlmServerPeerAddMessage: MessageHandler = async ( @@ -153,13 +135,21 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async ( clientSiteResourcesAssociationsCache.siteResourceId ) ) - .where( + .innerJoin( + networks, + eq(siteResources.networkId, 
networks.networkId) + ) + .innerJoin( + siteNetworks, and( - eq(siteResources.siteId, site.siteId), - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ) + eq(networks.networkId, siteNetworks.networkId), + eq(siteNetworks.siteId, site.siteId) + ) + ) + .where( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId ) ); diff --git a/server/routers/resource/getUserResources.ts b/server/routers/resource/getUserResources.ts index 802fffb1b..1722a7993 100644 --- a/server/routers/resource/getUserResources.ts +++ b/server/routers/resource/getUserResources.ts @@ -145,7 +145,7 @@ export async function getUserResources( niceId: string; destination: string; mode: string; - protocol: string | null; + scheme: string | null; enabled: boolean; alias: string | null; aliasAddress: string | null; @@ -158,7 +158,7 @@ export async function getUserResources( niceId: siteResources.niceId, destination: siteResources.destination, mode: siteResources.mode, - protocol: siteResources.protocol, + scheme: siteResources.scheme, enabled: siteResources.enabled, alias: siteResources.alias, aliasAddress: siteResources.aliasAddress @@ -242,7 +242,7 @@ export async function getUserResources( name: siteResource.name, destination: siteResource.destination, mode: siteResource.mode, - protocol: siteResource.protocol, + protocol: siteResource.scheme, enabled: siteResource.enabled, alias: siteResource.alias, aliasAddress: siteResource.aliasAddress, @@ -291,7 +291,7 @@ export type GetUserResourcesResponse = { enabled: boolean; alias: string | null; aliasAddress: string | null; - type: 'site'; + type: "site"; }>; }; }; diff --git a/server/routers/site/deleteSite.ts b/server/routers/site/deleteSite.ts index 587572535..344f6b4e3 100644 --- a/server/routers/site/deleteSite.ts +++ b/server/routers/site/deleteSite.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, Site, siteResources } from "@server/db"; 
+import { db, Site, siteNetworks, siteResources } from "@server/db"; import { newts, newtSessions, sites } from "@server/db"; import { eq } from "drizzle-orm"; import response from "@server/lib/response"; @@ -71,18 +71,23 @@ export async function deleteSite( await deletePeer(site.exitNodeId!, site.pubKey); } } else if (site.type == "newt") { - // delete all of the site resources on this site - const siteResourcesOnSite = trx - .delete(siteResources) - .where(eq(siteResources.siteId, siteId)) - .returning(); + const networks = await trx + .select({ networkId: siteNetworks.networkId }) + .from(siteNetworks) + .where(eq(siteNetworks.siteId, siteId)); // loop through them - for (const removedSiteResource of await siteResourcesOnSite) { - await rebuildClientAssociationsFromSiteResource( - removedSiteResource, - trx - ); + for (const network of await networks) { + const [siteResource] = await trx + .select() + .from(siteResources) + .where(eq(siteResources.networkId, network.networkId)); + if (siteResource) { + await rebuildClientAssociationsFromSiteResource( + siteResource, + trx + ); + } } // get the newt on the site by querying the newt table for siteId diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index 1485a4192..9a7d632fd 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -5,6 +5,8 @@ import { orgs, roles, roleSiteResources, + siteNetworks, + networks, SiteResource, siteResources, sites, @@ -17,17 +19,18 @@ import { portRangeStringSchema } from "@server/lib/ip"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; -import { tierMatrix } from "@server/lib/billing/tierMatrix"; +import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; import response from "@server/lib/response"; import logger 
from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import HttpCode from "@server/types/HttpCode"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; import { fromError } from "zod-validation-error"; +import { validateAndConstructDomain } from "@server/lib/domainUtils"; const createSiteResourceParamsSchema = z.strictObject({ orgId: z.string() @@ -36,11 +39,12 @@ const createSiteResourceParamsSchema = z.strictObject({ const createSiteResourceSchema = z .strictObject({ name: z.string().min(1).max(255), - mode: z.enum(["host", "cidr", "port"]), - siteId: z.int(), - // protocol: z.enum(["tcp", "udp"]).optional(), + mode: z.enum(["host", "cidr", "http"]), + ssl: z.boolean().optional(), // only used for http mode + scheme: z.enum(["http", "https"]).optional(), + siteIds: z.array(z.int()), // proxyPort: z.int().positive().optional(), - // destinationPort: z.int().positive().optional(), + destinationPort: z.int().positive().optional(), destination: z.string().min(1), enabled: z.boolean().default(true), alias: z @@ -57,20 +61,24 @@ const createSiteResourceSchema = z udpPortRangeString: portRangeStringSchema, disableIcmp: z.boolean().optional(), authDaemonPort: z.int().positive().optional(), - authDaemonMode: z.enum(["site", "remote"]).optional() + authDaemonMode: z.enum(["site", "remote"]).optional(), + domainId: z.string().optional(), // only used for http mode, we need this to verify the alias is unique within the org + subdomain: z.string().optional() // only used for http mode, we need this to verify the alias is unique within the org }) .strict() .refine( (data) => { if (data.mode === "host") { - // Check if it's a valid IP address using zod (v4 or v6) - const isValidIP = z - // .union([z.ipv4(), z.ipv6()]) - .union([z.ipv4()]) // for now lets just do ipv4 until we verify ipv6 
works everywhere - .safeParse(data.destination).success; + if (data.mode == "host") { + // Check if it's a valid IP address using zod (v4 or v6) + const isValidIP = z + // .union([z.ipv4(), z.ipv6()]) + .union([z.ipv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere + .safeParse(data.destination).success; - if (isValidIP) { - return true; + if (isValidIP) { + return true; + } } // Check if it's a valid domain (hostname pattern, TLD not required) @@ -105,6 +113,21 @@ const createSiteResourceSchema = z { message: "Destination must be a valid CIDR notation for cidr mode" } + ) + .refine( + (data) => { + if (data.mode !== "http") return true; + return ( + data.scheme !== undefined && + data.destinationPort !== undefined && + data.destinationPort >= 1 && + data.destinationPort <= 65535 + ); + }, + { + message: + "HTTP mode requires scheme (http or https) and a valid destination port" + } ); export type CreateSiteResourceBody = z.infer; @@ -159,13 +182,14 @@ export async function createSiteResource( const { orgId } = parsedParams.data; const { name, - siteId, + siteIds, mode, - // protocol, + scheme, // proxyPort, - // destinationPort, + destinationPort, destination, enabled, + ssl, alias, userIds, roleIds, @@ -174,18 +198,36 @@ export async function createSiteResource( udpPortRangeString, disableIcmp, authDaemonPort, - authDaemonMode + authDaemonMode, + domainId, + subdomain } = parsedBody.data; + if (mode == "http") { + const hasHttpFeature = await isLicensedOrSubscribed( + orgId, + tierMatrix[TierFeature.HTTPPrivateResources] + ); + if (!hasHttpFeature) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "HTTP private resources are not included in your current plan. Please upgrade." 
+ ) + ); + } + } + + // Verify the site exists and belongs to the org - const [site] = await db + const sitesToAssign = await db .select() .from(sites) - .where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) - .limit(1); + .where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId))); - if (!site) { - return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); + if (sitesToAssign.length !== siteIds.length) { + return next( + createHttpError(HttpCode.NOT_FOUND, "One or more sites not found") + ); } const [org] = await db @@ -226,29 +268,50 @@ ); } - // // check if resource with same protocol and proxy port already exists (only for port mode) - // if (mode === "port" && protocol && proxyPort) { - // const [existingResource] = await db - // .select() - // .from(siteResources) - // .where( - // and( - // eq(siteResources.siteId, siteId), - // eq(siteResources.orgId, orgId), - // eq(siteResources.protocol, protocol), - // eq(siteResources.proxyPort, proxyPort) - // ) - // ) - // .limit(1); - // if (existingResource && existingResource.siteResourceId) { - // return next( - // createHttpError( - // HttpCode.CONFLICT, - // "A resource with the same protocol and proxy port already exists" - // ) - // ); - // } - // } + if (domainId && alias) { + // throw an error because we can only have one or the other + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Alias and domain cannot both be set. Please choose one or the other." 
+ ) + ); + } + + let fullDomain: string | null = null; + let finalSubdomain: string | null = null; + if (domainId) { + // Validate domain and construct full domain + const domainResult = await validateAndConstructDomain( + domainId, + orgId, + subdomain + ); + + if (!domainResult.success) { + return next( + createHttpError(HttpCode.BAD_REQUEST, domainResult.error) + ); + } + + fullDomain = domainResult.fullDomain; + finalSubdomain = domainResult.subdomain; + + // make sure the full domain is unique + const existingResource = await db + .select() + .from(siteResources) + .where(eq(siteResources.fullDomain, fullDomain)); + + if (existingResource.length > 0) { + return next( + createHttpError( + HttpCode.CONFLICT, + "Resource with that domain already exists" + ) + ); + } + } // make sure the alias is unique within the org if provided if (alias) { @@ -280,27 +343,49 @@ export async function createSiteResource( const niceId = await getUniqueSiteResourceName(orgId); let aliasAddress: string | null = null; - if (mode == "host") { - // we can only have an alias on a host + if (mode === "host" || mode === "http") { aliasAddress = await getNextAvailableAliasAddress(orgId); } let newSiteResource: SiteResource | undefined; await db.transaction(async (trx) => { + const [network] = await trx + .insert(networks) + .values({ + scope: "resource", + orgId: orgId + }) + .returning(); + + if (!network) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + `Failed to create network` + ) + ); + } + // Create the site resource const insertValues: typeof siteResources.$inferInsert = { - siteId, niceId, orgId, name, - mode: mode as "host" | "cidr", + mode, + ssl, + networkId: network.networkId, destination, + scheme, + destinationPort, enabled, - alias, + alias: alias ? 
alias.trim() : null, aliasAddress, tcpPortRangeString, udpPortRangeString, - disableIcmp + disableIcmp, + domainId, + subdomain: finalSubdomain, + fullDomain }; if (isLicensedSshPam) { if (authDaemonPort !== undefined) @@ -317,6 +402,13 @@ export async function createSiteResource( //////////////////// update the associations //////////////////// + for (const siteId of siteIds) { + await trx.insert(siteNetworks).values({ + siteId: siteId, + networkId: network.networkId + }); + } + const [adminRole] = await trx .select() .from(roles) @@ -359,16 +451,21 @@ export async function createSiteResource( ); } - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, site.siteId)) - .limit(1); + for (const siteToAssign of sitesToAssign) { + const [newt] = await trx + .select() + .from(newts) + .where(eq(newts.siteId, siteToAssign.siteId)) + .limit(1); - if (!newt) { - return next( - createHttpError(HttpCode.NOT_FOUND, "Newt not found") - ); + if (!newt) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Newt not found for site ${siteToAssign.siteId}` + ) + ); + } } await rebuildClientAssociationsFromSiteResource( @@ -387,7 +484,7 @@ export async function createSiteResource( } logger.info( - `Created site resource ${newSiteResource.siteResourceId} for site ${siteId}` + `Created site resource ${newSiteResource.siteResourceId} for org ${orgId}` ); return response(res, { diff --git a/server/routers/siteResource/deleteSiteResource.ts b/server/routers/siteResource/deleteSiteResource.ts index 5b50b0ea3..8d08d545d 100644 --- a/server/routers/siteResource/deleteSiteResource.ts +++ b/server/routers/siteResource/deleteSiteResource.ts @@ -70,17 +70,18 @@ export async function deleteSiteResource( .where(and(eq(siteResources.siteResourceId, siteResourceId))) .returning(); - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, removedSiteResource.siteId)) - .limit(1); + // not sure why this is here... 
+ // const [newt] = await trx + // .select() + // .from(newts) + // .where(eq(newts.siteId, removedSiteResource.siteId)) + // .limit(1); - if (!newt) { - return next( - createHttpError(HttpCode.NOT_FOUND, "Newt not found") - ); - } + // if (!newt) { + // return next( + // createHttpError(HttpCode.NOT_FOUND, "Newt not found") + // ); + // } await rebuildClientAssociationsFromSiteResource( removedSiteResource, diff --git a/server/routers/siteResource/getSiteResource.ts b/server/routers/siteResource/getSiteResource.ts index be28d36e4..2e3dfe87b 100644 --- a/server/routers/siteResource/getSiteResource.ts +++ b/server/routers/siteResource/getSiteResource.ts @@ -17,38 +17,34 @@ const getSiteResourceParamsSchema = z.strictObject({ .transform((val) => (val ? Number(val) : undefined)) .pipe(z.int().positive().optional()) .optional(), - siteId: z.string().transform(Number).pipe(z.int().positive()), niceId: z.string().optional(), orgId: z.string() }); async function query( siteResourceId?: number, - siteId?: number, niceId?: string, orgId?: string ) { - if (siteResourceId && siteId && orgId) { + if (siteResourceId && orgId) { const [siteResource] = await db .select() .from(siteResources) .where( and( eq(siteResources.siteResourceId, siteResourceId), - eq(siteResources.siteId, siteId), eq(siteResources.orgId, orgId) ) ) .limit(1); return siteResource; - } else if (niceId && siteId && orgId) { + } else if (niceId && orgId) { const [siteResource] = await db .select() .from(siteResources) .where( and( eq(siteResources.niceId, niceId), - eq(siteResources.siteId, siteId), eq(siteResources.orgId, orgId) ) ) @@ -84,7 +80,6 @@ registry.registerPath({ request: { params: z.object({ niceId: z.string(), - siteId: z.number(), orgId: z.string() }) }, @@ -107,10 +102,10 @@ export async function getSiteResource( ); } - const { siteResourceId, siteId, niceId, orgId } = parsedParams.data; + const { siteResourceId, niceId, orgId } = parsedParams.data; // Get the site resource - const 
siteResource = await query(siteResourceId, siteId, niceId, orgId); + const siteResource = await query(siteResourceId, niceId, orgId); if (!siteResource) { return next( diff --git a/server/routers/siteResource/listAllSiteResourcesByOrg.ts b/server/routers/siteResource/listAllSiteResourcesByOrg.ts index 3320aa3b7..8750e7516 100644 --- a/server/routers/siteResource/listAllSiteResourcesByOrg.ts +++ b/server/routers/siteResource/listAllSiteResourcesByOrg.ts @@ -1,4 +1,4 @@ -import { db, SiteResource, siteResources, sites } from "@server/db"; +import { db, DB_TYPE, SiteResource, siteNetworks, siteResources, sites } from "@server/db"; import response from "@server/lib/response"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; @@ -41,12 +41,12 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({ }), query: z.string().optional(), mode: z - .enum(["host", "cidr"]) + .enum(["host", "cidr", "http"]) .optional() .catch(undefined) .openapi({ type: "string", - enum: ["host", "cidr"], + enum: ["host", "cidr", "http"], description: "Filter site resources by mode" }), sort_by: z @@ -73,22 +73,58 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({ export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{ siteResources: (SiteResource & { - siteName: string; - siteNiceId: string; - siteAddress: string | null; + siteOnlines: boolean[]; + siteIds: number[]; + siteNames: string[]; + siteNiceIds: string[]; + siteAddresses: (string | null)[]; })[]; }>; +/** + * Returns an aggregation expression compatible with both SQLite and PostgreSQL. 
+ * - SQLite: json_group_array(col) → returns a JSON array string, parsed after fetch + * - PostgreSQL: array_agg(col) → returns a native array + */ +function aggCol(column: any) { + if (DB_TYPE === "sqlite") { + return sql`json_group_array(${column})`; + } + return sql`array_agg(${column})`; +} + +/** + * For SQLite the aggregated columns come back as JSON strings; parse them into + * proper arrays. For PostgreSQL the driver already returns native arrays, so + * the row is returned unchanged. + */ +function transformSiteResourceRow(row: any) { + if (DB_TYPE !== "sqlite") { + return row; + } + return { + ...row, + siteNames: JSON.parse(row.siteNames) as string[], + siteNiceIds: JSON.parse(row.siteNiceIds) as string[], + siteIds: JSON.parse(row.siteIds) as number[], + siteAddresses: JSON.parse(row.siteAddresses) as (string | null)[], + // SQLite stores booleans as 0/1 integers + siteOnlines: (JSON.parse(row.siteOnlines) as (0 | 1)[]).map( + (v) => v === 1 + ) as boolean[] + }; +} + function querySiteResourcesBase() { return db .select({ siteResourceId: siteResources.siteResourceId, - siteId: siteResources.siteId, orgId: siteResources.orgId, niceId: siteResources.niceId, name: siteResources.name, mode: siteResources.mode, - protocol: siteResources.protocol, + ssl: siteResources.ssl, + scheme: siteResources.scheme, proxyPort: siteResources.proxyPort, destinationPort: siteResources.destinationPort, destination: siteResources.destination, @@ -100,12 +136,24 @@ function querySiteResourcesBase() { disableIcmp: siteResources.disableIcmp, authDaemonMode: siteResources.authDaemonMode, authDaemonPort: siteResources.authDaemonPort, - siteName: sites.name, - siteNiceId: sites.niceId, - siteAddress: sites.address + subdomain: siteResources.subdomain, + domainId: siteResources.domainId, + fullDomain: siteResources.fullDomain, + networkId: siteResources.networkId, + defaultNetworkId: siteResources.defaultNetworkId, + siteNames: aggCol(sites.name), + siteNiceIds: 
aggCol(sites.niceId), + siteIds: aggCol(sites.siteId), + siteAddresses: aggCol<(string | null)[]>(sites.address), + siteOnlines: aggCol(sites.online) }) .from(siteResources) - .innerJoin(sites, eq(siteResources.siteId, sites.siteId)); + .innerJoin( + siteNetworks, + eq(siteResources.networkId, siteNetworks.networkId) + ) + .innerJoin(sites, eq(siteNetworks.siteId, sites.siteId)) + .groupBy(siteResources.siteResourceId); } registry.registerPath({ @@ -193,10 +241,12 @@ export async function listAllSiteResourcesByOrg( const baseQuery = querySiteResourcesBase().where(and(...conditions)); const countQuery = db.$count( - querySiteResourcesBase().where(and(...conditions)).as("filtered_site_resources") + querySiteResourcesBase() + .where(and(...conditions)) + .as("filtered_site_resources") ); - const [siteResourcesList, totalCount] = await Promise.all([ + const [siteResourcesRaw, totalCount] = await Promise.all([ baseQuery .limit(pageSize) .offset(pageSize * (page - 1)) @@ -210,6 +260,8 @@ export async function listAllSiteResourcesByOrg( countQuery ]); + const siteResourcesList = siteResourcesRaw.map(transformSiteResourceRow); + return response(res, { data: { siteResources: siteResourcesList, @@ -233,4 +285,4 @@ export async function listAllSiteResourcesByOrg( ) ); } -} +} \ No newline at end of file diff --git a/server/routers/siteResource/listSiteResources.ts b/server/routers/siteResource/listSiteResources.ts index 358aa0497..8a1469f76 100644 --- a/server/routers/siteResource/listSiteResources.ts +++ b/server/routers/siteResource/listSiteResources.ts @@ -1,6 +1,6 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; +import { db, networks, siteNetworks } from "@server/db"; import { siteResources, sites, SiteResource } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; @@ -108,13 +108,21 @@ export async function listSiteResources( return 
next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); } - // Get site resources + // Get site resources by joining networks to siteResources via siteNetworks const siteResourcesList = await db .select() - .from(siteResources) + .from(siteNetworks) + .innerJoin( + networks, + eq(siteNetworks.networkId, networks.networkId) + ) + .innerJoin( + siteResources, + eq(siteResources.networkId, networks.networkId) + ) .where( and( - eq(siteResources.siteId, siteId), + eq(siteNetworks.siteId, siteId), eq(siteResources.orgId, orgId) ) ) @@ -128,6 +136,7 @@ export async function listSiteResources( .limit(limit) .offset(offset); + return response(res, { data: { siteResources: siteResourcesList }, success: true, diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index ab70d0fce..4335b55d3 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -1,4 +1,3 @@ -import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { clientSiteResources, clientSiteResourcesAssociationsCache, @@ -7,13 +6,21 @@ import { orgs, roles, roleSiteResources, + siteNetworks, SiteResource, siteResources, sites, + networks, Transaction, userSiteResources } from "@server/db"; -import { tierMatrix } from "@server/lib/billing/tierMatrix"; +import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; +import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; +import { validateAndConstructDomain } from "@server/lib/domainUtils"; +import response from "@server/lib/response"; +import { eq, and, ne, inArray } from "drizzle-orm"; +import { OpenAPITags, registry } from "@server/openApi"; +import { updatePeerData, updateTargets } from "@server/routers/client/targets"; import { generateAliasConfig, generateRemoteSubnets, @@ -22,12 +29,8 @@ import { portRangeStringSchema } from "@server/lib/ip"; import { 
rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; -import response from "@server/lib/response"; import logger from "@server/logger"; -import { OpenAPITags, registry } from "@server/openApi"; -import { updatePeerData, updateTargets } from "@server/routers/client/targets"; import HttpCode from "@server/types/HttpCode"; -import { and, eq, ne } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; @@ -40,7 +43,8 @@ const updateSiteResourceParamsSchema = z.strictObject({ const updateSiteResourceSchema = z .strictObject({ name: z.string().min(1).max(255).optional(), - siteId: z.int(), + siteIds: z.array(z.int()), + // niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(), niceId: z .string() .min(1) @@ -51,10 +55,11 @@ const updateSiteResourceSchema = z ) .optional(), // mode: z.enum(["host", "cidr", "port"]).optional(), - mode: z.enum(["host", "cidr"]).optional(), - // protocol: z.enum(["tcp", "udp"]).nullish(), + mode: z.enum(["host", "cidr", "http"]).optional(), + ssl: z.boolean().optional(), + scheme: z.enum(["http", "https"]).nullish(), // proxyPort: z.int().positive().nullish(), - // destinationPort: z.int().positive().nullish(), + destinationPort: z.int().positive().nullish(), destination: z.string().min(1).optional(), enabled: z.boolean().optional(), alias: z @@ -71,7 +76,9 @@ const updateSiteResourceSchema = z udpPortRangeString: portRangeStringSchema, disableIcmp: z.boolean().optional(), authDaemonPort: z.int().positive().nullish(), - authDaemonMode: z.enum(["site", "remote"]).optional() + authDaemonMode: z.enum(["site", "remote"]).optional(), + domainId: z.string().optional(), + subdomain: z.string().optional() }) .strict() .refine( @@ -118,6 +125,23 @@ const updateSiteResourceSchema = z { message: "Destination must be a valid CIDR notation for cidr mode" } 
+ ) + .refine( + (data) => { + if (data.mode !== "http") return true; + return ( + data.scheme !== undefined && + data.scheme !== null && + data.destinationPort !== undefined && + data.destinationPort !== null && + data.destinationPort >= 1 && + data.destinationPort <= 65535 + ); + }, + { + message: + "HTTP mode requires scheme (http or https) and a valid destination port" + } ); export type UpdateSiteResourceBody = z.infer; @@ -172,11 +196,14 @@ export async function updateSiteResource( const { siteResourceId } = parsedParams.data; const { name, - siteId, // because it can change + siteIds, // because it can change niceId, mode, + scheme, destination, + destinationPort, alias, + ssl, enabled, userIds, roleIds, @@ -185,19 +212,11 @@ export async function updateSiteResource( udpPortRangeString, disableIcmp, authDaemonPort, - authDaemonMode + authDaemonMode, + domainId, + subdomain } = parsedBody.data; - const [site] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); - - if (!site) { - return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); - } - // Check if site resource exists const [existingSiteResource] = await db .select() @@ -211,6 +230,21 @@ export async function updateSiteResource( ); } + if (mode == "http") { + const hasHttpFeature = await isLicensedOrSubscribed( + existingSiteResource.orgId, + tierMatrix[TierFeature.HTTPPrivateResources] + ); + if (!hasHttpFeature) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "HTTP private resources are not included in your current plan. Please upgrade." 
+ ) + ); + } + } + const isLicensedSshPam = await isLicensedOrSubscribed( existingSiteResource.orgId, tierMatrix.sshPam @@ -237,6 +271,23 @@ export async function updateSiteResource( ); } + // Verify the site exists and belongs to the org + const sitesToAssign = await db + .select() + .from(sites) + .where( + and( + inArray(sites.siteId, siteIds), + eq(sites.orgId, existingSiteResource.orgId) + ) + ); + + if (sitesToAssign.length !== siteIds.length) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Some site not found") + ); + } + // Only check if destination is an IP address const isIp = z .union([z.ipv4(), z.ipv6()]) @@ -254,22 +305,60 @@ export async function updateSiteResource( ); } - let existingSite = site; - let siteChanged = false; - if (existingSiteResource.siteId !== siteId) { - siteChanged = true; - // get the existing site - [existingSite] = await db - .select() - .from(sites) - .where(eq(sites.siteId, existingSiteResource.siteId)) - .limit(1); + let sitesChanged = false; + const existingSiteIds = existingSiteResource.networkId + ? 
await db + .select() + .from(siteNetworks) + .where( + eq(siteNetworks.networkId, existingSiteResource.networkId) + ) + : []; - if (!existingSite) { + const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId)); + const newSiteIdSet = new Set(siteIds); + + if ( + existingSiteIdSet.size !== newSiteIdSet.size || + ![...existingSiteIdSet].every((id) => newSiteIdSet.has(id)) + ) { + sitesChanged = true; + } + + let fullDomain: string | null = null; + let finalSubdomain: string | null = null; + if (domainId) { + // Validate domain and construct full domain + const domainResult = await validateAndConstructDomain( + domainId, + org.orgId, + subdomain + ); + + if (!domainResult.success) { + return next( + createHttpError(HttpCode.BAD_REQUEST, domainResult.error) + ); + } + + fullDomain = domainResult.fullDomain; + finalSubdomain = domainResult.subdomain; + + // make sure the full domain is unique + const [existingDomain] = await db + .select() + .from(siteResources) + .where(eq(siteResources.fullDomain, fullDomain)); + + if ( + existingDomain && + existingDomain.siteResourceId !== + existingSiteResource.siteResourceId + ) { return next( createHttpError( - HttpCode.NOT_FOUND, - "Existing site not found" + HttpCode.CONFLICT, + "Resource with that domain already exists" ) ); } @@ -302,7 +391,7 @@ export async function updateSiteResource( let updatedSiteResource: SiteResource | undefined; await db.transaction(async (trx) => { // if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place - if (siteChanged) { + if (sitesChanged) { // delete the existing site resource await trx .delete(siteResources) @@ -343,15 +432,20 @@ export async function updateSiteResource( .update(siteResources) .set({ name, - siteId, niceId, mode, + scheme, + ssl, destination, + destinationPort, enabled, - alias: alias && alias.trim() ? alias : null, + alias: alias ? 
alias.trim() : null, tcpPortRangeString, udpPortRangeString, disableIcmp, + domainId, + subdomain: finalSubdomain, + fullDomain, ...sshPamSet }) .where( @@ -372,6 +466,23 @@ export async function updateSiteResource( //////////////////// update the associations //////////////////// + // delete the site - site resources associations + await trx + .delete(siteNetworks) + .where( + eq( + siteNetworks.networkId, + updatedSiteResource.networkId! + ) + ); + + for (const siteId of siteIds) { + await trx.insert(siteNetworks).values({ + siteId: siteId, + networkId: updatedSiteResource.networkId! + }); + } + const [adminRole] = await trx .select() .from(roles) @@ -447,14 +558,20 @@ export async function updateSiteResource( .update(siteResources) .set({ name: name, - siteId: siteId, + niceId: niceId, mode: mode, + scheme, + ssl, destination: destination, + destinationPort: destinationPort, enabled: enabled, - alias: alias && alias.trim() ? alias : null, + alias: alias ? alias.trim() : null, tcpPortRangeString: tcpPortRangeString, udpPortRangeString: udpPortRangeString, disableIcmp: disableIcmp, + domainId, + subdomain: finalSubdomain, + fullDomain, ...sshPamSet }) .where( @@ -464,6 +581,23 @@ export async function updateSiteResource( //////////////////// update the associations //////////////////// + // delete the site - site resources associations + await trx + .delete(siteNetworks) + .where( + eq( + siteNetworks.networkId, + updatedSiteResource.networkId! + ) + ); + + for (const siteId of siteIds) { + await trx.insert(siteNetworks).values({ + siteId: siteId, + networkId: updatedSiteResource.networkId! 
+ }); + } + await trx .delete(clientSiteResources) .where( @@ -533,14 +667,15 @@ export async function updateSiteResource( ); } - logger.info( - `Updated site resource ${siteResourceId} for site ${siteId}` - ); + logger.info(`Updated site resource ${siteResourceId}`); await handleMessagingForUpdatedSiteResource( existingSiteResource, updatedSiteResource, - { siteId: site.siteId, orgId: site.orgId }, + siteIds.map((siteId) => ({ + siteId, + orgId: existingSiteResource.orgId + })), trx ); } @@ -567,7 +702,7 @@ export async function updateSiteResource( export async function handleMessagingForUpdatedSiteResource( existingSiteResource: SiteResource | undefined, updatedSiteResource: SiteResource, - site: { siteId: number; orgId: string }, + sites: { siteId: number; orgId: string }[], trx: Transaction ) { logger.debug( @@ -589,9 +724,14 @@ export async function handleMessagingForUpdatedSiteResource( const destinationChanged = existingSiteResource && existingSiteResource.destination !== updatedSiteResource.destination; + const destinationPortChanged = + existingSiteResource && + existingSiteResource.destinationPort !== + updatedSiteResource.destinationPort; const aliasChanged = existingSiteResource && - existingSiteResource.alias !== updatedSiteResource.alias; + (existingSiteResource.alias !== updatedSiteResource.alias || + existingSiteResource.fullDomain !== updatedSiteResource.fullDomain); // because the full domain gets sent down to the stuff as an alias const portRangesChanged = existingSiteResource && (existingSiteResource.tcpPortRangeString !== @@ -603,106 +743,122 @@ export async function handleMessagingForUpdatedSiteResource( // if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all - if (destinationChanged || aliasChanged || portRangesChanged) { - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, site.siteId)) - .limit(1); - - if (!newt) { - throw new Error( - "Newt 
not found for site during site resource update" - ); - } - - // Only update targets on newt if destination changed - if (destinationChanged || portRangesChanged) { - const oldTargets = generateSubnetProxyTargetV2( - existingSiteResource, - mergedAllClients - ); - const newTargets = generateSubnetProxyTargetV2( - updatedSiteResource, - mergedAllClients - ); - - await updateTargets( - newt.newtId, - { - oldTargets: oldTargets ? oldTargets : [], - newTargets: newTargets ? newTargets : [] - }, - newt.version - ); - } - - const olmJobs: Promise[] = []; - for (const client of mergedAllClients) { - // does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet - // todo: optimize this query if needed - const oldDestinationStillInUseSites = await trx + if ( + destinationChanged || + aliasChanged || + portRangesChanged || + destinationPortChanged + ) { + for (const site of sites) { + const [newt] = await trx .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - clientSiteResourcesAssociationsCache.siteResourceId, - siteResources.siteResourceId - ) - ) - .where( - and( - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ), - eq(siteResources.siteId, site.siteId), - eq( - siteResources.destination, - existingSiteResource.destination - ), - ne( - siteResources.siteResourceId, - existingSiteResource.siteResourceId - ) - ) + .from(newts) + .where(eq(newts.siteId, site.siteId)) + .limit(1); + + if (!newt) { + throw new Error( + "Newt not found for site during site resource update" + ); + } + + // Only update targets on newt if destination changed + if ( + destinationChanged || + portRangesChanged || + destinationPortChanged + ) { + const oldTargets = await generateSubnetProxyTargetV2( + existingSiteResource, + mergedAllClients + ); + const newTargets = await generateSubnetProxyTargetV2( + updatedSiteResource, + mergedAllClients 
); - const oldDestinationStillInUseByASite = - oldDestinationStillInUseSites.length > 0; + await updateTargets( + newt.newtId, + { + oldTargets: oldTargets ? oldTargets : [], + newTargets: newTargets ? newTargets : [] + }, + newt.version + ); + } - // we also need to update the remote subnets on the olms for each client that has access to this site - olmJobs.push( - updatePeerData( - client.clientId, - updatedSiteResource.siteId, - destinationChanged - ? { - oldRemoteSubnets: !oldDestinationStillInUseByASite - ? generateRemoteSubnets([ - existingSiteResource - ]) - : [], - newRemoteSubnets: generateRemoteSubnets([ - updatedSiteResource - ]) - } - : undefined, - aliasChanged - ? { - oldAliases: generateAliasConfig([ - existingSiteResource - ]), - newAliases: generateAliasConfig([ - updatedSiteResource - ]) - } - : undefined - ) - ); + const olmJobs: Promise[] = []; + for (const client of mergedAllClients) { + // does this client have access to another resource on this site that has the same destination still? 
if so we dont want to remove it from their olm yet + // todo: optimize this query if needed + const oldDestinationStillInUseSites = await trx + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + siteResources.siteResourceId + ) + ) + .innerJoin( + siteNetworks, + eq(siteNetworks.networkId, siteResources.networkId) + ) + .where( + and( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ), + eq(siteNetworks.siteId, site.siteId), + eq( + siteResources.destination, + existingSiteResource.destination + ), + ne( + siteResources.siteResourceId, + existingSiteResource.siteResourceId + ) + ) + ); + + const oldDestinationStillInUseByASite = + oldDestinationStillInUseSites.length > 0; + + // we also need to update the remote subnets on the olms for each client that has access to this site + olmJobs.push( + updatePeerData( + client.clientId, + site.siteId, + destinationChanged + ? { + oldRemoteSubnets: + !oldDestinationStillInUseByASite + ? generateRemoteSubnets([ + existingSiteResource + ]) + : [], + newRemoteSubnets: generateRemoteSubnets([ + updatedSiteResource + ]) + } + : undefined, + aliasChanged + ? 
{ + oldAliases: generateAliasConfig([ + existingSiteResource + ]), + newAliases: generateAliasConfig([ + updatedSiteResource + ]) + } + : undefined + ) + ); + } + + await Promise.all(olmJobs); } - - await Promise.all(olmJobs); } } diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index 143e4d516..2dc09eedc 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -2,7 +2,7 @@ import { build } from "@server/build"; import { handleNewtRegisterMessage, handleReceiveBandwidthMessage, - handleGetConfigMessage, + handleNewtGetConfigMessage, handleDockerStatusMessage, handleDockerContainersMessage, handleNewtPingRequestMessage, @@ -37,7 +37,7 @@ export const messageHandlers: Record = { "newt/disconnecting": handleNewtDisconnectingMessage, "newt/ping": handleNewtPingMessage, "newt/wg/register": handleNewtRegisterMessage, - "newt/wg/get-config": handleGetConfigMessage, + "newt/wg/get-config": handleNewtGetConfigMessage, "newt/receive-bandwidth": handleReceiveBandwidthMessage, "newt/socket/status": handleDockerStatusMessage, "newt/socket/containers": handleDockerContainersMessage, diff --git a/src/app/[orgId]/settings/logs/access/page.tsx b/src/app/[orgId]/settings/logs/access/page.tsx index a0f1b5386..826e11c17 100644 --- a/src/app/[orgId]/settings/logs/access/page.tsx +++ b/src/app/[orgId]/settings/logs/access/page.tsx @@ -471,11 +471,7 @@ export default function GeneralPage() { : `/${row.original.orgId}/settings/resources/proxy/${row.original.resourceNiceId}` } > - diff --git a/src/app/[orgId]/settings/logs/connection/page.tsx b/src/app/[orgId]/settings/logs/connection/page.tsx index e15708f8e..6eaedff5a 100644 --- a/src/app/[orgId]/settings/logs/connection/page.tsx +++ b/src/app/[orgId]/settings/logs/connection/page.tsx @@ -451,11 +451,7 @@ export default function ConnectionLogsPage() { - @@ -497,11 +493,7 @@ export default function ConnectionLogsPage() { - @@ -634,6 +636,7 @@ export default 
function GeneralPage() { { value: "105", label: t("validPassword") }, { value: "106", label: t("validEmail") }, { value: "107", label: t("validSSO") }, + { value: "108", label: t("connectedClient") }, { value: "201", label: t("resourceNotFound") }, { value: "202", label: t("resourceBlocked") }, { value: "203", label: t("droppedByRule") }, diff --git a/src/app/[orgId]/settings/resources/client/page.tsx b/src/app/[orgId]/settings/resources/client/page.tsx index f0f582f0f..c15e3d429 100644 --- a/src/app/[orgId]/settings/resources/client/page.tsx +++ b/src/app/[orgId]/settings/resources/client/page.tsx @@ -60,23 +60,34 @@ export default async function ClientResourcesPage( id: siteResource.siteResourceId, name: siteResource.name, orgId: params.orgId, - siteName: siteResource.siteName, - siteAddress: siteResource.siteAddress || null, - mode: siteResource.mode || ("port" as any), + sites: siteResource.siteIds.map((siteId, idx) => ({ + siteId, + siteName: siteResource.siteNames[idx], + siteNiceId: siteResource.siteNiceIds[idx], + online: siteResource.siteOnlines[idx] + })), + mode: siteResource.mode, + scheme: siteResource.scheme, + ssl: siteResource.ssl, + siteNames: siteResource.siteNames, + siteAddresses: siteResource.siteAddresses || null, // protocol: siteResource.protocol, // proxyPort: siteResource.proxyPort, - siteId: siteResource.siteId, + siteIds: siteResource.siteIds, destination: siteResource.destination, - // destinationPort: siteResource.destinationPort, + httpHttpsPort: siteResource.destinationPort ?? null, alias: siteResource.alias || null, aliasAddress: siteResource.aliasAddress || null, - siteNiceId: siteResource.siteNiceId, + siteNiceIds: siteResource.siteNiceIds, niceId: siteResource.niceId, tcpPortRangeString: siteResource.tcpPortRangeString || null, udpPortRangeString: siteResource.udpPortRangeString || null, disableIcmp: siteResource.disableIcmp || false, authDaemonMode: siteResource.authDaemonMode ?? 
null, - authDaemonPort: siteResource.authDaemonPort ?? null + authDaemonPort: siteResource.authDaemonPort ?? null, + subdomain: siteResource.subdomain ?? null, + domainId: siteResource.domainId ?? null, + fullDomain: siteResource.fullDomain ?? null }; } ); diff --git a/src/app/private-maintenance-screen/page.tsx b/src/app/private-maintenance-screen/page.tsx new file mode 100644 index 000000000..21417b6f4 --- /dev/null +++ b/src/app/private-maintenance-screen/page.tsx @@ -0,0 +1,32 @@ +import { Metadata } from "next"; +import { getTranslations } from "next-intl/server"; +import { + Card, + CardContent, + CardHeader, + CardTitle +} from "@app/components/ui/card"; + +export const dynamic = "force-dynamic"; + +export const metadata: Metadata = { + title: "Private Placeholder" +}; + +export default async function MaintenanceScreen() { + const t = await getTranslations(); + + let title = t("privateMaintenanceScreenTitle"); + let message = t("privateMaintenanceScreenMessage"); + + return ( +
+ + + {title} + + {message} + +
+ ); +} diff --git a/src/components/ClientResourcesTable.tsx b/src/components/ClientResourcesTable.tsx index 5066f273d..c32208321 100644 --- a/src/components/ClientResourcesTable.tsx +++ b/src/components/ClientResourcesTable.tsx @@ -21,6 +21,7 @@ import { ArrowUp10Icon, ArrowUpDown, ArrowUpRight, + ChevronDown, ChevronsUpDownIcon, MoreHorizontal } from "lucide-react"; @@ -38,21 +39,32 @@ import { ControlledDataTable } from "./ui/controlled-data-table"; import { useNavigationContext } from "@app/hooks/useNavigationContext"; import { useDebouncedCallback } from "use-debounce"; import { ColumnFilterButton } from "./ColumnFilterButton"; +import { cn } from "@app/lib/cn"; + +export type InternalResourceSiteRow = { + siteId: number; + siteName: string; + siteNiceId: string; + online: boolean; +}; export type InternalResourceRow = { id: number; name: string; orgId: string; - siteName: string; - siteAddress: string | null; + sites: InternalResourceSiteRow[]; + siteNames: string[]; + siteAddresses: (string | null)[]; + siteIds: number[]; + siteNiceIds: string[]; // mode: "host" | "cidr" | "port"; - mode: "host" | "cidr"; + mode: "host" | "cidr" | "http"; + scheme: "http" | "https" | null; + ssl: boolean; // protocol: string | null; // proxyPort: number | null; - siteId: number; - siteNiceId: string; destination: string; - // destinationPort: number | null; + httpHttpsPort: number | null; alias: string | null; aliasAddress: string | null; niceId: string; @@ -61,8 +73,147 @@ export type InternalResourceRow = { disableIcmp: boolean; authDaemonMode?: "site" | "remote" | null; authDaemonPort?: number | null; + subdomain?: string | null; + domainId?: string | null; + fullDomain?: string | null; }; +function resolveHttpHttpsDisplayPort( + mode: "http", + httpHttpsPort: number | null +): number { + if (httpHttpsPort != null) { + return httpHttpsPort; + } + return 80; +} + +function formatDestinationDisplay(row: InternalResourceRow): string { + const { mode, destination, 
httpHttpsPort, scheme } = row; + if (mode !== "http") { + return destination; + } + const port = resolveHttpHttpsDisplayPort(mode, httpHttpsPort); + const downstreamScheme = scheme ?? "http"; + const hostPart = + destination.includes(":") && !destination.startsWith("[") + ? `[${destination}]` + : destination; + return `${downstreamScheme}://${hostPart}:${port}`; +} + +function isSafeUrlForLink(href: string): boolean { + try { + void new URL(href); + return true; + } catch { + return false; + } +} + +type AggregateSitesStatus = "allOnline" | "partial" | "allOffline"; + +function aggregateSitesStatus( + resourceSites: InternalResourceSiteRow[] +): AggregateSitesStatus { + if (resourceSites.length === 0) { + return "allOffline"; + } + const onlineCount = resourceSites.filter((rs) => rs.online).length; + if (onlineCount === resourceSites.length) return "allOnline"; + if (onlineCount > 0) return "partial"; + return "allOffline"; +} + +function aggregateStatusDotClass(status: AggregateSitesStatus): string { + switch (status) { + case "allOnline": + return "bg-green-500"; + case "partial": + return "bg-yellow-500"; + case "allOffline": + default: + return "bg-gray-500"; + } +} + +function ClientResourceSitesStatusCell({ + orgId, + resourceSites +}: { + orgId: string; + resourceSites: InternalResourceSiteRow[]; +}) { + const t = useTranslations(); + + if (resourceSites.length === 0) { + return -; + } + + const aggregate = aggregateSitesStatus(resourceSites); + const countLabel = t("multiSitesSelectorSitesCount", { + count: resourceSites.length + }); + + return ( + + + + + + {resourceSites.map((site) => { + const isOnline = site.online; + return ( + + +
+
+ + {site.siteName} + +
+ + {isOnline ? t("online") : t("offline")} + + + + ); + })} + + + ); +} + type ClientResourcesTableProps = { internalResources: InternalResourceRow[]; orgId: string; @@ -97,8 +248,6 @@ export default function ClientResourcesTable({ useState(); const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); - const { data: sites = [] } = useQuery(orgQueries.sites({ orgId })); - const [isRefreshing, startTransition] = useTransition(); const refreshData = () => { @@ -136,6 +285,60 @@ export default function ClientResourcesTable({ } }; + function SiteCell({ resourceRow }: { resourceRow: InternalResourceRow }) { + const { siteNames, siteNiceIds, orgId } = resourceRow; + + if (!siteNames || siteNames.length === 0) { + return -; + } + + if (siteNames.length === 1) { + return ( + + + + ); + } + + return ( + + + + + + {siteNames.map((siteName, idx) => ( + + + {siteName} + + + + ))} + + + ); + } + const internalColumns: ExtendedColumnDef[] = [ { accessorKey: "name", @@ -185,20 +388,17 @@ export default function ClientResourcesTable({ } }, { - accessorKey: "siteName", - friendlyName: t("site"), - header: () => {t("site")}, + id: "sites", + accessorFn: (row) => row.sites.map((s) => s.siteName).join(", "), + friendlyName: t("sites"), + header: () => {t("sites")}, cell: ({ row }) => { const resourceRow = row.original; return ( - - - + ); } }, @@ -215,6 +415,10 @@ export default function ClientResourcesTable({ { value: "cidr", label: t("editInternalResourceDialogModeCidr") + }, + { + value: "http", + label: t("editInternalResourceDialogModeHttp") } ]} selectedValue={searchParams.get("mode") ?? 
undefined} @@ -227,10 +431,14 @@ export default function ClientResourcesTable({ ), cell: ({ row }) => { const resourceRow = row.original; - const modeLabels: Record<"host" | "cidr" | "port", string> = { + const modeLabels: Record< + "host" | "cidr" | "port" | "http", + string + > = { host: t("editInternalResourceDialogModeHost"), cidr: t("editInternalResourceDialogModeCidr"), - port: t("editInternalResourceDialogModePort") + port: t("editInternalResourceDialogModePort"), + http: t("editInternalResourceDialogModeHttp") }; return {modeLabels[resourceRow.mode]}; } @@ -243,11 +451,12 @@ export default function ClientResourcesTable({ ), cell: ({ row }) => { const resourceRow = row.original; + const display = formatDestinationDisplay(resourceRow); return ( ); } @@ -260,15 +469,26 @@ export default function ClientResourcesTable({ ), cell: ({ row }) => { const resourceRow = row.original; - return resourceRow.mode === "host" && resourceRow.alias ? ( - - ) : ( - - - ); + if (resourceRow.mode === "host" && resourceRow.alias) { + return ( + + ); + } + if (resourceRow.mode === "http") { + const url = `${resourceRow.ssl ? 
"https" : "http"}://${resourceRow.fullDomain}`; + return ( + + ); + } + return -; } }, { @@ -399,7 +619,7 @@ export default function ClientResourcesTable({ onConfirm={async () => deleteInternalResource( selectedInternalResource!.id, - selectedInternalResource!.siteId + selectedInternalResource!.siteIds[0] ) } string={selectedInternalResource.name} @@ -435,7 +655,6 @@ export default function ClientResourcesTable({ setOpen={setIsEditDialogOpen} resource={editingResource} orgId={orgId} - sites={sites} onSuccess={() => { // Delay refresh to allow modal to close smoothly setTimeout(() => { @@ -450,7 +669,6 @@ export default function ClientResourcesTable({ open={isCreateDialogOpen} setOpen={setIsCreateDialogOpen} orgId={orgId} - sites={sites} onSuccess={() => { // Delay refresh to allow modal to close smoothly setTimeout(() => { diff --git a/src/components/CreateInternalResourceDialog.tsx b/src/components/CreateInternalResourceDialog.tsx index d5ca61acc..4d2bc0916 100644 --- a/src/components/CreateInternalResourceDialog.tsx +++ b/src/components/CreateInternalResourceDialog.tsx @@ -14,7 +14,6 @@ import { Button } from "@app/components/ui/button"; import { useEnvContext } from "@app/hooks/useEnvContext"; import { toast } from "@app/hooks/useToast"; import { createApiClient, formatAxiosError } from "@app/lib/api"; -import { ListSitesResponse } from "@server/routers/site"; import { AxiosResponse } from "axios"; import { useTranslations } from "next-intl"; import { useState } from "react"; @@ -25,13 +24,10 @@ import { type InternalResourceFormValues } from "./InternalResourceForm"; -type Site = ListSitesResponse["sites"][0]; - type CreateInternalResourceDialogProps = { open: boolean; setOpen: (val: boolean) => void; orgId: string; - sites: Site[]; onSuccess?: () => void; }; @@ -39,18 +35,21 @@ export default function CreateInternalResourceDialog({ open, setOpen, orgId, - sites, onSuccess }: CreateInternalResourceDialogProps) { const t = useTranslations(); const api = 
createApiClient(useEnvContext()); const [isSubmitting, setIsSubmitting] = useState(false); + const [isHttpModeDisabled, setIsHttpModeDisabled] = useState(false); async function handleSubmit(values: InternalResourceFormValues) { setIsSubmitting(true); try { let data = { ...values }; - if (data.mode === "host" && isHostname(data.destination)) { + if ( + (data.mode === "host" || data.mode === "http") && + isHostname(data.destination) + ) { const currentAlias = data.alias?.trim() || ""; if (!currentAlias) { let aliasValue = data.destination; @@ -65,25 +64,56 @@ export default function CreateInternalResourceDialog({ `/org/${orgId}/site-resource`, { name: data.name, - siteId: data.siteId, + siteIds: data.siteIds, mode: data.mode, destination: data.destination, enabled: true, - alias: data.alias && typeof data.alias === "string" && data.alias.trim() ? data.alias : undefined, - tcpPortRangeString: data.tcpPortRangeString, - udpPortRangeString: data.udpPortRangeString, - disableIcmp: data.disableIcmp ?? false, - ...(data.authDaemonMode != null && { authDaemonMode: data.authDaemonMode }), - ...(data.authDaemonMode === "remote" && data.authDaemonPort != null && { authDaemonPort: data.authDaemonPort }), - roleIds: data.roles ? data.roles.map((r) => parseInt(r.id)) : [], + ...(data.mode === "http" && { + scheme: data.scheme, + ssl: data.ssl ?? false, + destinationPort: data.httpHttpsPort ?? undefined, + domainId: data.httpConfigDomainId + ? data.httpConfigDomainId + : undefined, + subdomain: data.httpConfigSubdomain + ? data.httpConfigSubdomain + : undefined + }), + ...(data.mode === "host" && { + alias: + data.alias && + typeof data.alias === "string" && + data.alias.trim() + ? 
data.alias + : undefined, + ...(data.authDaemonMode != null && { + authDaemonMode: data.authDaemonMode + }), + ...(data.authDaemonMode === "remote" && + data.authDaemonPort != null && { + authDaemonPort: data.authDaemonPort + }) + }), + ...((data.mode === "host" || data.mode == "cidr") && { + tcpPortRangeString: data.tcpPortRangeString, + udpPortRangeString: data.udpPortRangeString, + disableIcmp: data.disableIcmp ?? false + }), + roleIds: data.roles + ? data.roles.map((r) => parseInt(r.id)) + : [], userIds: data.users ? data.users.map((u) => u.id) : [], - clientIds: data.clients ? data.clients.map((c) => parseInt(c.id)) : [] + clientIds: data.clients + ? data.clients.map((c) => parseInt(c.id)) + : [] } ); toast({ title: t("createInternalResourceDialogSuccess"), - description: t("createInternalResourceDialogInternalResourceCreatedSuccessfully"), + description: t( + "createInternalResourceDialogInternalResourceCreatedSuccessfully" + ), variant: "default" }); setOpen(false); @@ -93,7 +123,9 @@ export default function CreateInternalResourceDialog({ title: t("createInternalResourceDialogError"), description: formatAxiosError( error, - t("createInternalResourceDialogFailedToCreateInternalResource") + t( + "createInternalResourceDialogFailedToCreateInternalResource" + ) ), variant: "destructive" }); @@ -106,31 +138,39 @@ export default function CreateInternalResourceDialog({ - {t("createInternalResourceDialogCreateClientResource")} + + {t("createInternalResourceDialogCreateClientResource")} + - {t("createInternalResourceDialogCreateClientResourceDescription")} + {t( + "createInternalResourceDialogCreateClientResourceDescription" + )} - - - - - { - setSelectedSite(site); - field.onChange(site.siteId); - }} - /> - - - - - )} - />
-
-
+
+
( - + - {t(modeLabelKey)} + {t("sites")} - + + + + + + + + { + setSelectedSites( + sites + ); + field.onChange( + sites.map( + (s) => + s.siteId + ) + ); + }} + /> + + )} />
+
+ { + const modeOptions: OptionSelectOption[] = + [ + { + value: "host", + label: t(modeHostKey) + }, + { + value: "cidr", + label: t(modeCidrKey) + }, + { + value: "http", + label: t(modeHttpKey) + } + ]; + return ( + + + {t(modeLabelKey)} + + + options={modeOptions} + value={field.value} + onChange={ + field.onChange + } + cols={3} + /> + + + ); + }} + /> +
+
+
+ {mode === "http" && ( +
+ ( + + + {t(schemeLabelKey)} + + + + + )} + /> +
+ )}
- + )} />
- {mode !== "cidr" && ( -
+ {mode === "host" && ( +
)} + {mode === "http" && ( +
+ ( + + + {t( + httpHttpsPortLabelKey + )} + + + { + const raw = + e.target + .value; + if ( + raw === "" + ) { + field.onChange( + null + ); + return; + } + const n = + Number(raw); + field.onChange( + Number.isFinite( + n + ) + ? n + : null + ); + }} + /> + + + + )} + /> +
+ )}
-
-
- -
- {t( - "editInternalResourceDialogPortRestrictionsDescription" + {isHttpMode && ( + + )} + + {isHttpMode ? ( +
+
+ +
+ {t(httpConfigurationDescriptionKey)} +
+
+
+ { + if (res === null) { + form.setValue( + "httpConfigSubdomain", + null + ); + form.setValue( + "httpConfigDomainId", + null + ); + form.setValue( + "httpConfigFullDomain", + null + ); + return; + } + form.setValue( + "httpConfigSubdomain", + res.subdomain ?? null + ); + form.setValue( + "httpConfigDomainId", + res.domainId + ); + form.setValue( + "httpConfigFullDomain", + res.fullDomain + ); + }} + /> +
+ ( + + + + + )} -
+ />
-
-
- - {t("editInternalResourceDialogTcp")} - -
-
- ( - -
- - {tcpPortMode === - "custom" ? ( - - - setTcpCustomPorts( - e.target - .value - ) - } - /> - - ) : ( - - )} -
- -
+ ) : ( +
+
+ +
+ {t( + "editInternalResourceDialogPortRestrictionsDescription" )} - /> -
-
-
-
- - {t("editInternalResourceDialogUdp")} - +
- ( - -
- - {udpPortMode === - "custom" ? ( - - - setUdpCustomPorts( - e.target - .value - ) - } - /> - - ) : ( - - )} -
- -
- )} - /> -
-
-
-
- - {t("editInternalResourceDialogIcmp")} - -
-
- ( - -
- - + + {t("editInternalResourceDialogTcp")} + +
+
+ ( + +
+ + {tcpPortMode === + "custom" ? ( + + + setTcpCustomPorts( + e + .target + .value + ) + } + /> + + ) : ( + + )} +
+ +
+ )} + /> +
+
+
+
+ + {t("editInternalResourceDialogUdp")} + +
+
+ ( + +
+ + {udpPortMode === + "custom" ? ( + + + setUdpCustomPorts( + e + .target + .value + ) + } + /> + + ) : ( + + )} +
+ +
+ )} + /> +
+
+
+
+ + {t( + "editInternalResourceDialogIcmp" + )} + +
+
+ ( + +
+ + + field.onChange( + !checked + ) + } + /> + + + {field.value + ? t("blocked") + : t("allowed")} + +
+ +
+ )} + /> +
-
+ )}
@@ -1213,8 +1579,8 @@ export function InternalResourceForm({ )}
- {/* SSH Access tab */} - {!disableEnterpriseFeatures && mode !== "cidr" && ( + {/* SSH Access tab (host mode only) */} + {!disableEnterpriseFeatures && mode === "host" && (
diff --git a/src/components/LogDataTable.tsx b/src/components/LogDataTable.tsx index 3a53a859f..14e87ff75 100644 --- a/src/components/LogDataTable.tsx +++ b/src/components/LogDataTable.tsx @@ -405,7 +405,11 @@ export function LogDataTable({ onClick={() => !disabled && onExport() } - disabled={isExporting || disabled || isExportDisabled} + disabled={ + isExporting || + disabled || + isExportDisabled + } > {isExporting ? ( diff --git a/src/components/PendingSitesTable.tsx b/src/components/PendingSitesTable.tsx index c65cb218e..a1ed6f354 100644 --- a/src/components/PendingSitesTable.tsx +++ b/src/components/PendingSitesTable.tsx @@ -353,9 +353,9 @@ export default function PendingSitesTable({ - ); diff --git a/src/components/ShareLinksTable.tsx b/src/components/ShareLinksTable.tsx index efac77df3..333cee03f 100644 --- a/src/components/ShareLinksTable.tsx +++ b/src/components/ShareLinksTable.tsx @@ -144,9 +144,9 @@ export default function ShareLinksTable({ - ); diff --git a/src/components/SitesTable.tsx b/src/components/SitesTable.tsx index 6cca706a6..606630a50 100644 --- a/src/components/SitesTable.tsx +++ b/src/components/SitesTable.tsx @@ -363,9 +363,9 @@ export default function SitesTable({ - ); diff --git a/src/components/UserDevicesTable.tsx b/src/components/UserDevicesTable.tsx index 4c5331015..5542029a6 100644 --- a/src/components/UserDevicesTable.tsx +++ b/src/components/UserDevicesTable.tsx @@ -373,12 +373,12 @@ export default function UserDevicesTable({ - ) : ( diff --git a/src/components/WorldMap.tsx b/src/components/WorldMap.tsx index ac227c553..09548400b 100644 --- a/src/components/WorldMap.tsx +++ b/src/components/WorldMap.tsx @@ -218,7 +218,7 @@ function drawInteractiveCountries( }); hoverPath .datum(country) - .attr("d", path(country) as string) + .attr("d", path(country as any) as string) .style("display", null); }) diff --git a/src/components/multi-site-selector.tsx b/src/components/multi-site-selector.tsx new file mode 100644 index 
000000000..407e3b3e1 --- /dev/null +++ b/src/components/multi-site-selector.tsx @@ -0,0 +1,117 @@ +import { orgQueries } from "@app/lib/queries"; +import { useQuery } from "@tanstack/react-query"; +import { useMemo, useState } from "react"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList +} from "./ui/command"; +import { Checkbox } from "./ui/checkbox"; +import { useTranslations } from "next-intl"; +import { useDebounce } from "use-debounce"; +import type { Selectedsite } from "./site-selector"; + +export type MultiSitesSelectorProps = { + orgId: string; + selectedSites: Selectedsite[]; + onSelectionChange: (sites: Selectedsite[]) => void; + filterTypes?: string[]; +}; + +export function formatMultiSitesSelectorLabel( + selectedSites: Selectedsite[], + t: (key: string, values?: { count: number }) => string +): string { + if (selectedSites.length === 0) { + return t("selectSites"); + } + if (selectedSites.length === 1) { + return selectedSites[0]!.name; + } + return t("multiSitesSelectorSitesCount", { + count: selectedSites.length + }); +} + +export function MultiSitesSelector({ + orgId, + selectedSites, + onSelectionChange, + filterTypes +}: MultiSitesSelectorProps) { + const t = useTranslations(); + const [siteSearchQuery, setSiteSearchQuery] = useState(""); + const [debouncedQuery] = useDebounce(siteSearchQuery, 150); + + const { data: sites = [] } = useQuery( + orgQueries.sites({ + orgId, + query: debouncedQuery, + perPage: 10 + }) + ); + + const sitesShown = useMemo(() => { + const base = filterTypes + ? 
sites.filter((s) => filterTypes.includes(s.type)) + : [...sites]; + if (debouncedQuery.trim().length === 0 && selectedSites.length > 0) { + const selectedNotInBase = selectedSites.filter( + (sel) => !base.some((s) => s.siteId === sel.siteId) + ); + return [...selectedNotInBase, ...base]; + } + return base; + }, [debouncedQuery, sites, selectedSites, filterTypes]); + + const selectedIds = useMemo( + () => new Set(selectedSites.map((s) => s.siteId)), + [selectedSites] + ); + + const toggleSite = (site: Selectedsite) => { + if (selectedIds.has(site.siteId)) { + onSelectionChange( + selectedSites.filter((s) => s.siteId !== site.siteId) + ); + } else { + onSelectionChange([...selectedSites, site]); + } + }; + + return ( + + setSiteSearchQuery(v)} + /> + + {t("siteNotFound")} + + {sitesShown.map((site) => ( + { + toggleSite(site); + }} + > + {}} + aria-hidden + tabIndex={-1} + /> + {site.name} + + ))} + + + + ); +} diff --git a/src/components/resource-target-address-item.tsx b/src/components/resource-target-address-item.tsx index 851b64b54..c801844ce 100644 --- a/src/components/resource-target-address-item.tsx +++ b/src/components/resource-target-address-item.tsx @@ -12,14 +12,6 @@ import { useTranslations } from "next-intl"; import { useMemo, useState } from "react"; import { ContainersSelector } from "./ContainersSelector"; import { Button } from "./ui/button"; -import { - Command, - CommandEmpty, - CommandGroup, - CommandInput, - CommandItem, - CommandList -} from "./ui/command"; import { Input } from "./ui/input"; import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { Select, SelectContent, SelectItem, SelectTrigger } from "./ui/select"; @@ -212,6 +204,12 @@ export function ResourceTargetAddressItem({ proxyTarget.port === 0 ? 
"" : proxyTarget.port } className="w-18.75 px-2 border-none placeholder-gray-400 rounded-l-xs" + type="number" + onKeyDown={(e) => { + if (["e", "E", "+", "-", "."].includes(e.key)) { + e.preventDefault(); + } + }} onBlur={(e) => { const value = parseInt(e.target.value, 10); if (!isNaN(value) && value > 0) { @@ -227,6 +225,7 @@ export function ResourceTargetAddressItem({ } }} /> +
); diff --git a/src/components/ui/checkbox.tsx b/src/components/ui/checkbox.tsx index 261655bb0..5cffd8978 100644 --- a/src/components/ui/checkbox.tsx +++ b/src/components/ui/checkbox.tsx @@ -43,8 +43,8 @@ const Checkbox = React.forwardRef< className={cn(checkboxVariants({ variant }), className)} {...props} > - - + + )); diff --git a/src/lib/queries.ts b/src/lib/queries.ts index 2fd34e8ac..d7822d6cf 100644 --- a/src/lib/queries.ts +++ b/src/lib/queries.ts @@ -155,7 +155,8 @@ export const orgQueries = { queryKey: ["ORG", orgId, "SITES", { query, perPage }] as const, queryFn: async ({ signal, meta }) => { const sp = new URLSearchParams({ - pageSize: perPage.toString() + pageSize: perPage.toString(), + status: "approved" }); if (query?.trim()) {