Compare commits

..

1 Commits

Author SHA1 Message Date
Owen
48abb9e98c Breakout sites tables 2026-04-08 22:04:12 -04:00
52 changed files with 1064 additions and 1522 deletions

View File

@@ -2113,11 +2113,9 @@
"addDomainToEnableCustomAuthPages": "Users will be able to access the organization's login page and complete resource authentication using this domain.", "addDomainToEnableCustomAuthPages": "Users will be able to access the organization's login page and complete resource authentication using this domain.",
"selectDomainForOrgAuthPage": "Select a domain for the organization's authentication page", "selectDomainForOrgAuthPage": "Select a domain for the organization's authentication page",
"domainPickerProvidedDomain": "Provided Domain", "domainPickerProvidedDomain": "Provided Domain",
"domainPickerFreeProvidedDomain": "Provided Domain", "domainPickerFreeProvidedDomain": "Free Provided Domain",
"domainPickerFreeDomainsPaidFeature": "Provided domains are a paid feature. Subscribe to get a domain included with your plan — no need to bring your own.",
"domainPickerVerified": "Verified", "domainPickerVerified": "Verified",
"domainPickerUnverified": "Unverified", "domainPickerUnverified": "Unverified",
"domainPickerManual": "Manual",
"domainPickerInvalidSubdomainStructure": "This subdomain contains invalid characters or structure. It will be sanitized automatically when you save.", "domainPickerInvalidSubdomainStructure": "This subdomain contains invalid characters or structure. It will be sanitized automatically when you save.",
"domainPickerError": "Error", "domainPickerError": "Error",
"domainPickerErrorLoadDomains": "Failed to load organization domains", "domainPickerErrorLoadDomains": "Failed to load organization domains",

View File

@@ -89,12 +89,8 @@ export const sites = pgTable("sites", {
name: varchar("name").notNull(), name: varchar("name").notNull(),
pubKey: varchar("pubKey"), pubKey: varchar("pubKey"),
subnet: varchar("subnet"), subnet: varchar("subnet"),
megabytesIn: real("bytesIn").default(0),
megabytesOut: real("bytesOut").default(0),
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
type: varchar("type").notNull(), // "newt" or "wireguard" type: varchar("type").notNull(), // "newt" or "wireguard"
online: boolean("online").notNull().default(false), online: boolean("online").notNull().default(false),
lastPing: integer("lastPing"),
address: varchar("address"), address: varchar("address"),
endpoint: varchar("endpoint"), endpoint: varchar("endpoint"),
publicKey: varchar("publicKey"), publicKey: varchar("publicKey"),
@@ -222,18 +218,12 @@ export const exitNodes = pgTable("exitNodes", {
export const siteResources = pgTable("siteResources", { export const siteResources = pgTable("siteResources", {
// this is for the clients // this is for the clients
siteResourceId: serial("siteResourceId").primaryKey(), siteResourceId: serial("siteResourceId").primaryKey(),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: varchar("orgId") orgId: varchar("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
defaultNetworkId: integer("defaultNetworkId").references(
() => networks.networkId,
{
onDelete: "restrict"
}
),
niceId: varchar("niceId").notNull(), niceId: varchar("niceId").notNull(),
name: varchar("name").notNull(), name: varchar("name").notNull(),
mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
@@ -253,32 +243,6 @@ export const siteResources = pgTable("siteResources", {
.default("site") .default("site")
}); });
export const networks = pgTable("networks", {
networkId: serial("networkId").primaryKey(),
niceId: text("niceId"),
name: text("name"),
scope: varchar("scope")
.$type<"global" | "resource">()
.notNull()
.default("global"),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull()
});
export const siteNetworks = pgTable("siteNetworks", {
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, {
onDelete: "cascade"
}),
networkId: integer("networkId")
.notNull()
.references(() => networks.networkId, { onDelete: "cascade" })
});
export const clientSiteResources = pgTable("clientSiteResources", { export const clientSiteResources = pgTable("clientSiteResources", {
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
@@ -761,10 +725,7 @@ export const clients = pgTable("clients", {
name: varchar("name").notNull(), name: varchar("name").notNull(),
pubKey: varchar("pubKey"), pubKey: varchar("pubKey"),
subnet: varchar("subnet").notNull(), subnet: varchar("subnet").notNull(),
megabytesIn: real("bytesIn"),
megabytesOut: real("bytesOut"),
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
lastPing: integer("lastPing"),
type: varchar("type").notNull(), // "olm" type: varchar("type").notNull(), // "olm"
online: boolean("online").notNull().default(false), online: boolean("online").notNull().default(false),
// endpoint: varchar("endpoint"), // endpoint: varchar("endpoint"),
@@ -777,6 +738,42 @@ export const clients = pgTable("clients", {
>() >()
}); });
export const sitePing = pgTable("sitePing", {
siteId: integer("siteId")
.primaryKey()
.references(() => sites.siteId, { onDelete: "cascade" })
.notNull(),
lastPing: integer("lastPing")
});
export const siteBandwidth = pgTable("siteBandwidth", {
siteId: integer("siteId")
.primaryKey()
.references(() => sites.siteId, { onDelete: "cascade" })
.notNull(),
megabytesIn: real("bytesIn").default(0),
megabytesOut: real("bytesOut").default(0),
lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch
});
export const clientPing = pgTable("clientPing", {
clientId: integer("clientId")
.primaryKey()
.references(() => clients.clientId, { onDelete: "cascade" })
.notNull(),
lastPing: integer("lastPing")
});
export const clientBandwidth = pgTable("clientBandwidth", {
clientId: integer("clientId")
.primaryKey()
.references(() => clients.clientId, { onDelete: "cascade" })
.notNull(),
megabytesIn: real("bytesIn"),
megabytesOut: real("bytesOut"),
lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch
});
export const clientSitesAssociationsCache = pgTable( export const clientSitesAssociationsCache = pgTable(
"clientSitesAssociationsCache", "clientSitesAssociationsCache",
{ {
@@ -1138,4 +1135,7 @@ export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type RoundTripMessageTracker = InferSelectModel< export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker typeof roundTripMessageTracker
>; >;
export type Network = InferSelectModel<typeof networks>; export type SitePing = typeof sitePing.$inferSelect;
export type SiteBandwidth = typeof siteBandwidth.$inferSelect;
export type ClientPing = typeof clientPing.$inferSelect;
export type ClientBandwidth = typeof clientBandwidth.$inferSelect;

View File

@@ -92,18 +92,11 @@ export const sites = sqliteTable("sites", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null" onDelete: "set null"
}), }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
name: text("name").notNull(), name: text("name").notNull(),
pubKey: text("pubKey"), pubKey: text("pubKey"),
subnet: text("subnet"), subnet: text("subnet"),
megabytesIn: integer("bytesIn").default(0),
megabytesOut: integer("bytesOut").default(0),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard" type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false), online: integer("online", { mode: "boolean" }).notNull().default(false),
lastPing: integer("lastPing"),
// exit node stuff that is how to connect to the site when it has a wg server // exit node stuff that is how to connect to the site when it has a wg server
address: text("address"), // this is the address of the wireguard interface in newt address: text("address"), // this is the address of the wireguard interface in newt
@@ -253,16 +246,12 @@ export const siteResources = sqliteTable("siteResources", {
siteResourceId: integer("siteResourceId").primaryKey({ siteResourceId: integer("siteResourceId").primaryKey({
autoIncrement: true autoIncrement: true
}), }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: text("orgId") orgId: text("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
defaultNetworkId: integer("defaultNetworkId").references(
() => networks.networkId,
{ onDelete: "restrict" }
),
niceId: text("niceId").notNull(), niceId: text("niceId").notNull(),
name: text("name").notNull(), name: text("name").notNull(),
mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
@@ -284,30 +273,6 @@ export const siteResources = sqliteTable("siteResources", {
.default("site") .default("site")
}); });
export const networks = sqliteTable("networks", {
networkId: integer("networkId").primaryKey({ autoIncrement: true }),
niceId: text("niceId"),
name: text("name"),
scope: text("scope")
.$type<"global" | "resource">()
.notNull()
.default("global"),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const siteNetworks = sqliteTable("siteNetworks", {
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, {
onDelete: "cascade"
}),
networkId: integer("networkId")
.notNull()
.references(() => networks.networkId, { onDelete: "cascade" })
});
export const clientSiteResources = sqliteTable("clientSiteResources", { export const clientSiteResources = sqliteTable("clientSiteResources", {
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
@@ -430,10 +395,7 @@ export const clients = sqliteTable("clients", {
pubKey: text("pubKey"), pubKey: text("pubKey"),
olmId: text("olmId"), // to lock it to a specific olm optionally olmId: text("olmId"), // to lock it to a specific olm optionally
subnet: text("subnet").notNull(), subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
lastPing: integer("lastPing"),
type: text("type").notNull(), // "olm" type: text("type").notNull(), // "olm"
online: integer("online", { mode: "boolean" }).notNull().default(false), online: integer("online", { mode: "boolean" }).notNull().default(false),
// endpoint: text("endpoint"), // endpoint: text("endpoint"),
@@ -445,6 +407,42 @@ export const clients = sqliteTable("clients", {
>() >()
}); });
export const sitePing = sqliteTable("sitePing", {
siteId: integer("siteId")
.primaryKey()
.references(() => sites.siteId, { onDelete: "cascade" })
.notNull(),
lastPing: integer("lastPing")
});
export const siteBandwidth = sqliteTable("siteBandwidth", {
siteId: integer("siteId")
.primaryKey()
.references(() => sites.siteId, { onDelete: "cascade" })
.notNull(),
megabytesIn: integer("bytesIn").default(0),
megabytesOut: integer("bytesOut").default(0),
lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch
});
export const clientPing = sqliteTable("clientPing", {
clientId: integer("clientId")
.primaryKey()
.references(() => clients.clientId, { onDelete: "cascade" })
.notNull(),
lastPing: integer("lastPing")
});
export const clientBandwidth = sqliteTable("clientBandwidth", {
clientId: integer("clientId")
.primaryKey()
.references(() => clients.clientId, { onDelete: "cascade" })
.notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch
});
export const clientSitesAssociationsCache = sqliteTable( export const clientSitesAssociationsCache = sqliteTable(
"clientSitesAssociationsCache", "clientSitesAssociationsCache",
{ {
@@ -1226,7 +1224,6 @@ export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>; export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>; export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type SiteResource = InferSelectModel<typeof siteResources>; export type SiteResource = InferSelectModel<typeof siteResources>;
export type Network = InferSelectModel<typeof networks>;
export type OrgDomains = InferSelectModel<typeof orgDomains>; export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SetupToken = InferSelectModel<typeof setupTokens>; export type SetupToken = InferSelectModel<typeof setupTokens>;
export type HostMeta = InferSelectModel<typeof hostMeta>; export type HostMeta = InferSelectModel<typeof hostMeta>;
@@ -1241,3 +1238,7 @@ export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RoundTripMessageTracker = InferSelectModel< export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker typeof roundTripMessageTracker
>; >;
export type SitePing = typeof sitePing.$inferSelect;
export type SiteBandwidth = typeof siteBandwidth.$inferSelect;
export type ClientPing = typeof clientPing.$inferSelect;
export type ClientBandwidth = typeof clientBandwidth.$inferSelect;

View File

@@ -19,8 +19,7 @@ export enum TierFeature {
SshPam = "sshPam", SshPam = "sshPam",
FullRbac = "fullRbac", FullRbac = "fullRbac",
SiteProvisioningKeys = "siteProvisioningKeys", // handle downgrade by revoking keys if needed SiteProvisioningKeys = "siteProvisioningKeys", // handle downgrade by revoking keys if needed
SIEM = "siem", // handle downgrade by disabling SIEM integrations SIEM = "siem" // handle downgrade by disabling SIEM integrations
DomainNamespaces = "domainNamespaces" // handle downgrade by removing custom domain namespaces
} }
export const tierMatrix: Record<TierFeature, Tier[]> = { export const tierMatrix: Record<TierFeature, Tier[]> = {
@@ -57,6 +56,5 @@ export const tierMatrix: Record<TierFeature, Tier[]> = {
[TierFeature.SshPam]: ["tier1", "tier3", "enterprise"], [TierFeature.SshPam]: ["tier1", "tier3", "enterprise"],
[TierFeature.FullRbac]: ["tier1", "tier2", "tier3", "enterprise"], [TierFeature.FullRbac]: ["tier1", "tier2", "tier3", "enterprise"],
[TierFeature.SiteProvisioningKeys]: ["tier3", "enterprise"], [TierFeature.SiteProvisioningKeys]: ["tier3", "enterprise"],
[TierFeature.SIEM]: ["enterprise"], [TierFeature.SIEM]: ["enterprise"]
[TierFeature.DomainNamespaces]: ["tier1", "tier2", "tier3", "enterprise"]
}; };

View File

@@ -121,8 +121,8 @@ export async function applyBlueprint({
for (const result of clientResourcesResults) { for (const result of clientResourcesResults) {
if ( if (
result.oldSiteResource && result.oldSiteResource &&
JSON.stringify(result.newSites?.sort()) !== result.oldSiteResource.siteId !=
JSON.stringify(result.oldSites?.sort()) result.newSiteResource.siteId
) { ) {
// query existing associations // query existing associations
const existingRoleIds = await trx const existingRoleIds = await trx
@@ -222,46 +222,38 @@ export async function applyBlueprint({
trx trx
); );
} else { } else {
let good = true; const [newSite] = await trx
for (const newSite of result.newSites) { .select()
const [site] = await trx .from(sites)
.select() .innerJoin(newts, eq(sites.siteId, newts.siteId))
.from(sites) .where(
.innerJoin(newts, eq(sites.siteId, newts.siteId)) and(
.where( eq(sites.siteId, result.newSiteResource.siteId),
and( eq(sites.orgId, orgId),
eq(sites.siteId, newSite.siteId), eq(sites.type, "newt"),
eq(sites.orgId, orgId), isNotNull(sites.pubKey)
eq(sites.type, "newt"),
isNotNull(sites.pubKey)
)
) )
.limit(1); )
.limit(1);
if (!site) {
logger.debug(
`No newt sites found for client resource ${result.newSiteResource.siteResourceId}, skipping target update`
);
good = false;
break;
}
if (!newSite) {
logger.debug( logger.debug(
`Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.siteId}` `No newt site found for client resource ${result.newSiteResource.siteResourceId}, skipping target update`
); );
}
if (!good) {
continue; continue;
} }
logger.debug(
`Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.sites.siteId}`
);
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
result.oldSiteResource, result.oldSiteResource,
result.newSiteResource, result.newSiteResource,
result.newSites.map((site) => ({ {
siteId: site.siteId, siteId: newSite.sites.siteId,
orgId: result.newSiteResource.orgId orgId: newSite.sites.orgId
})), },
trx trx
); );
} }

View File

@@ -3,15 +3,12 @@ import {
clientSiteResources, clientSiteResources,
roles, roles,
roleSiteResources, roleSiteResources,
Site,
SiteResource, SiteResource,
siteNetworks,
siteResources, siteResources,
Transaction, Transaction,
userOrgs, userOrgs,
users, users,
userSiteResources, userSiteResources
networks
} from "@server/db"; } from "@server/db";
import { sites } from "@server/db"; import { sites } from "@server/db";
import { eq, and, ne, inArray, or } from "drizzle-orm"; import { eq, and, ne, inArray, or } from "drizzle-orm";
@@ -22,8 +19,6 @@ import { getNextAvailableAliasAddress } from "../ip";
export type ClientResourcesResults = { export type ClientResourcesResults = {
newSiteResource: SiteResource; newSiteResource: SiteResource;
oldSiteResource?: SiteResource; oldSiteResource?: SiteResource;
newSites: { siteId: number }[];
oldSites: { siteId: number }[];
}[]; }[];
export async function updateClientResources( export async function updateClientResources(
@@ -48,70 +43,36 @@ export async function updateClientResources(
) )
.limit(1); .limit(1);
const existingSiteIds = existingResource?.networkId const resourceSiteId = resourceData.site;
? await trx let site;
.select({ siteId: sites.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, existingResource.networkId))
: [];
let allSites: { siteId: number }[] = []; if (resourceSiteId) {
if (resourceData.site) { // Look up site by niceId
let siteSingle; [site] = await trx
const resourceSiteId = resourceData.site; .select({ siteId: sites.siteId })
.from(sites)
if (resourceSiteId) { .where(
// Look up site by niceId and(
[siteSingle] = await trx eq(sites.niceId, resourceSiteId),
.select({ siteId: sites.siteId }) eq(sites.orgId, orgId)
.from(sites)
.where(
and(
eq(sites.niceId, resourceSiteId),
eq(sites.orgId, orgId)
)
) )
.limit(1); )
} else if (siteId) { .limit(1);
// Use the provided siteId directly, but verify it belongs to the org } else if (siteId) {
[siteSingle] = await trx // Use the provided siteId directly, but verify it belongs to the org
.select({ siteId: sites.siteId }) [site] = await trx
.from(sites) .select({ siteId: sites.siteId })
.where( .from(sites)
and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)) .where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
) .limit(1);
.limit(1); } else {
} else { throw new Error(`Target site is required`);
throw new Error(`Target site is required`);
}
if (!siteSingle) {
throw new Error(
`Site not found: ${resourceSiteId} in org ${orgId}`
);
}
allSites.push(siteSingle);
} }
if (resourceData.sites) { if (!site) {
for (const siteNiceId of resourceData.sites) { throw new Error(
const [site] = await trx `Site not found: ${resourceSiteId} in org ${orgId}`
.select({ siteId: sites.siteId }) );
.from(sites)
.where(
and(
eq(sites.niceId, siteNiceId),
eq(sites.orgId, orgId)
)
)
.limit(1);
if (!site) {
throw new Error(
`Site not found: ${siteId} in org ${orgId}`
);
}
allSites.push(site);
}
} }
if (existingResource) { if (existingResource) {
@@ -120,6 +81,7 @@ export async function updateClientResources(
.update(siteResources) .update(siteResources)
.set({ .set({
name: resourceData.name || resourceNiceId, name: resourceData.name || resourceNiceId,
siteId: site.siteId,
mode: resourceData.mode, mode: resourceData.mode,
destination: resourceData.destination, destination: resourceData.destination,
enabled: true, // hardcoded for now enabled: true, // hardcoded for now
@@ -140,21 +102,6 @@ export async function updateClientResources(
const siteResourceId = existingResource.siteResourceId; const siteResourceId = existingResource.siteResourceId;
const orgId = existingResource.orgId; const orgId = existingResource.orgId;
if (updatedResource.networkId) {
await trx
.delete(siteNetworks)
.where(
eq(siteNetworks.networkId, updatedResource.networkId)
);
for (const site of allSites) {
await trx.insert(siteNetworks).values({
siteId: site.siteId,
networkId: updatedResource.networkId
});
}
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where(eq(clientSiteResources.siteResourceId, siteResourceId)); .where(eq(clientSiteResources.siteResourceId, siteResourceId));
@@ -257,9 +204,7 @@ export async function updateClientResources(
results.push({ results.push({
newSiteResource: updatedResource, newSiteResource: updatedResource,
oldSiteResource: existingResource, oldSiteResource: existingResource
newSites: allSites,
oldSites: existingSiteIds
}); });
} else { } else {
let aliasAddress: string | null = null; let aliasAddress: string | null = null;
@@ -268,22 +213,13 @@ export async function updateClientResources(
aliasAddress = await getNextAvailableAliasAddress(orgId); aliasAddress = await getNextAvailableAliasAddress(orgId);
} }
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
// Create new resource // Create new resource
const [newResource] = await trx const [newResource] = await trx
.insert(siteResources) .insert(siteResources)
.values({ .values({
orgId: orgId, orgId: orgId,
siteId: site.siteId,
niceId: resourceNiceId, niceId: resourceNiceId,
networkId: network.networkId,
defaultNetworkId: network.networkId,
name: resourceData.name || resourceNiceId, name: resourceData.name || resourceNiceId,
mode: resourceData.mode, mode: resourceData.mode,
destination: resourceData.destination, destination: resourceData.destination,
@@ -299,13 +235,6 @@ export async function updateClientResources(
const siteResourceId = newResource.siteResourceId; const siteResourceId = newResource.siteResourceId;
for (const site of allSites) {
await trx.insert(siteNetworks).values({
siteId: site.siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -395,11 +324,7 @@ export async function updateClientResources(
`Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}` `Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}`
); );
results.push({ results.push({ newSiteResource: newResource });
newSiteResource: newResource,
newSites: allSites,
oldSites: existingSiteIds
});
} }
} }

View File

@@ -326,8 +326,7 @@ export const ClientResourceSchema = z
.object({ .object({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr"]), mode: z.enum(["host", "cidr"]),
site: z.string(), // DEPRECATED IN FAVOR OF sites site: z.string(),
sites: z.array(z.string()).optional().default([]),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),

View File

@@ -11,11 +11,11 @@ import {
roleSiteResources, roleSiteResources,
Site, Site,
SiteResource, SiteResource,
siteNetworks,
siteResources, siteResources,
sites, sites,
Transaction, Transaction,
userOrgRoles, userOrgRoles,
userOrgs,
userSiteResources userSiteResources
} from "@server/db"; } from "@server/db";
import { and, eq, inArray, ne } from "drizzle-orm"; import { and, eq, inArray, ne } from "drizzle-orm";
@@ -48,23 +48,15 @@ export async function getClientSiteResourceAccess(
siteResource: SiteResource, siteResource: SiteResource,
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
) { ) {
// get all sites associated with this siteResource via its network // get the site
const sitesList = siteResource.networkId const [site] = await trx
? await trx .select()
.select() .from(sites)
.from(sites) .where(eq(sites.siteId, siteResource.siteId))
.innerJoin( .limit(1);
siteNetworks,
eq(siteNetworks.siteId, sites.siteId)
)
.where(eq(siteNetworks.networkId, siteResource.networkId))
.then((rows) => rows.map((row) => row.sites))
: [];
if (sitesList.length === 0) { if (!site) {
logger.warn( throw new Error(`Site with ID ${siteResource.siteId} not found`);
`No sites found for siteResource ${siteResource.siteResourceId} with networkId ${siteResource.networkId}`
);
} }
const roleIds = await trx const roleIds = await trx
@@ -145,7 +137,7 @@ export async function getClientSiteResourceAccess(
const mergedAllClientIds = mergedAllClients.map((c) => c.clientId); const mergedAllClientIds = mergedAllClients.map((c) => c.clientId);
return { return {
sitesList, site,
mergedAllClients, mergedAllClients,
mergedAllClientIds mergedAllClientIds
}; };
@@ -161,51 +153,40 @@ export async function rebuildClientAssociationsFromSiteResource(
subnet: string | null; subnet: string | null;
}[]; }[];
}> { }> {
const { sitesList, mergedAllClients, mergedAllClientIds } = const siteId = siteResource.siteId;
const { site, mergedAllClients, mergedAllClientIds } =
await getClientSiteResourceAccess(siteResource, trx); await getClientSiteResourceAccess(siteResource, trx);
/////////// process the client-siteResource associations /////////// /////////// process the client-siteResource associations ///////////
// get all of the clients associated with other resources in the same network, // get all of the clients associated with other resources on this site
// joined through siteNetworks so we know which siteId each client belongs to const allUpdatedClientsFromOtherResourcesOnThisSite = await trx
const allUpdatedClientsFromOtherResourcesOnThisSite = siteResource.networkId .select({
? await trx clientId: clientSiteResourcesAssociationsCache.clientId
.select({ })
clientId: clientSiteResourcesAssociationsCache.clientId, .from(clientSiteResourcesAssociationsCache)
siteId: siteNetworks.siteId .innerJoin(
}) siteResources,
.from(clientSiteResourcesAssociationsCache) eq(
.innerJoin( clientSiteResourcesAssociationsCache.siteResourceId,
siteResources, siteResources.siteResourceId
eq( )
clientSiteResourcesAssociationsCache.siteResourceId, )
siteResources.siteResourceId .where(
) and(
) eq(siteResources.siteId, siteId),
.innerJoin( ne(siteResources.siteResourceId, siteResource.siteResourceId)
siteNetworks, )
eq(siteNetworks.networkId, siteResources.networkId) );
)
.where(
and(
eq(siteResources.networkId, siteResource.networkId),
ne(
siteResources.siteResourceId,
siteResource.siteResourceId
)
)
)
: [];
// Build a per-site map so the loop below can check by siteId rather than const allClientIdsFromOtherResourcesOnThisSite = Array.from(
// across the entire network. new Set(
const clientsFromOtherResourcesBySite = new Map<number, Set<number>>(); allUpdatedClientsFromOtherResourcesOnThisSite.map(
for (const row of allUpdatedClientsFromOtherResourcesOnThisSite) { (row) => row.clientId
if (!clientsFromOtherResourcesBySite.has(row.siteId)) { )
clientsFromOtherResourcesBySite.set(row.siteId, new Set()); )
} );
clientsFromOtherResourcesBySite.get(row.siteId)!.add(row.clientId);
}
const existingClientSiteResources = await trx const existingClientSiteResources = await trx
.select({ .select({
@@ -279,90 +260,82 @@ export async function rebuildClientAssociationsFromSiteResource(
/////////// process the client-site associations /////////// /////////// process the client-site associations ///////////
for (const site of sitesList) { const existingClientSites = await trx
const siteId = site.siteId; .select({
clientId: clientSitesAssociationsCache.clientId
})
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.siteId, siteResource.siteId));
const existingClientSites = await trx const existingClientSiteIds = existingClientSites.map(
.select({ (row) => row.clientId
clientId: clientSitesAssociationsCache.clientId );
})
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.siteId, siteId));
const existingClientSiteIds = existingClientSites.map( // Get full client details for existing clients (needed for sending delete messages)
(row) => row.clientId const existingClients = await trx
); .select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.where(inArray(clients.clientId, existingClientSiteIds));
// Get full client details for existing clients (needed for sending delete messages) const clientSitesToAdd = mergedAllClientIds.filter(
const existingClients = (clientId) =>
existingClientSiteIds.length > 0 !existingClientSiteIds.includes(clientId) &&
? await trx !allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource
.select({ );
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.where(inArray(clients.clientId, existingClientSiteIds))
: [];
const otherResourceClientIds = clientsFromOtherResourcesBySite.get(siteId) ?? new Set<number>(); const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({
clientId,
siteId
}));
const clientSitesToAdd = mergedAllClientIds.filter( if (clientSitesToInsert.length > 0) {
(clientId) => await trx
!existingClientSiteIds.includes(clientId) && .insert(clientSitesAssociationsCache)
!otherResourceClientIds.has(clientId) // dont add if already connected via another site resource .values(clientSitesToInsert)
); .returning();
const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({
clientId,
siteId
}));
if (clientSitesToInsert.length > 0) {
await trx
.insert(clientSitesAssociationsCache)
.values(clientSitesToInsert)
.returning();
}
// Now remove any client-site associations that should no longer exist
const clientSitesToRemove = existingClientSiteIds.filter(
(clientId) =>
!mergedAllClientIds.includes(clientId) &&
!otherResourceClientIds.has(clientId) // dont remove if there is still another connection for another site resource
);
if (clientSitesToRemove.length > 0) {
await trx
.delete(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.siteId, siteId),
inArray(
clientSitesAssociationsCache.clientId,
clientSitesToRemove
)
)
);
}
// Now handle the messages to add/remove peers on both the newt and olm sides
await handleMessagesForSiteClients(
site,
siteId,
mergedAllClients,
existingClients,
clientSitesToAdd,
clientSitesToRemove,
trx
);
} }
// Now remove any client-site associations that should no longer exist
const clientSitesToRemove = existingClientSiteIds.filter(
(clientId) =>
!mergedAllClientIds.includes(clientId) &&
!allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource
);
if (clientSitesToRemove.length > 0) {
await trx
.delete(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.siteId, siteId),
inArray(
clientSitesAssociationsCache.clientId,
clientSitesToRemove
)
)
);
}
/////////// send the messages ///////////
// Now handle the messages to add/remove peers on both the newt and olm sides
await handleMessagesForSiteClients(
site,
siteId,
mergedAllClients,
existingClients,
clientSitesToAdd,
clientSitesToRemove,
trx
);
// Handle subnet proxy target updates for the resource associations // Handle subnet proxy target updates for the resource associations
await handleSubnetProxyTargetUpdates( await handleSubnetProxyTargetUpdates(
siteResource, siteResource,
sitesList,
mergedAllClients, mergedAllClients,
existingResourceClients, existingResourceClients,
clientSiteResourcesToAdd, clientSiteResourcesToAdd,
@@ -651,7 +624,6 @@ export async function updateClientSiteDestinations(
async function handleSubnetProxyTargetUpdates( async function handleSubnetProxyTargetUpdates(
siteResource: SiteResource, siteResource: SiteResource,
sitesList: Site[],
allClients: { allClients: {
clientId: number; clientId: number;
pubKey: string | null; pubKey: string | null;
@@ -666,138 +638,125 @@ async function handleSubnetProxyTargetUpdates(
clientSiteResourcesToRemove: number[], clientSiteResourcesToRemove: number[],
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
): Promise<void> { ): Promise<void> {
const proxyJobs: Promise<any>[] = []; // Get the newt for this site
const olmJobs: Promise<any>[] = []; const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, siteResource.siteId))
.limit(1);
for (const siteData of sitesList) { if (!newt) {
const siteId = siteData.siteId; logger.warn(
`Newt not found for site ${siteResource.siteId}, skipping subnet proxy target updates`
);
return;
}
// Get the newt for this site const proxyJobs = [];
const [newt] = await trx const olmJobs = [];
.select() // Generate targets for added associations
.from(newts) if (clientSiteResourcesToAdd.length > 0) {
.where(eq(newts.siteId, siteId)) const addedClients = allClients.filter((client) =>
.limit(1); clientSiteResourcesToAdd.includes(client.clientId)
);
if (!newt) { if (addedClients.length > 0) {
logger.warn( const targetToAdd = generateSubnetProxyTargetV2(
`Newt not found for site ${siteId}, skipping subnet proxy target updates` siteResource,
); addedClients
continue;
}
// Generate targets for added associations
if (clientSiteResourcesToAdd.length > 0) {
const addedClients = allClients.filter((client) =>
clientSiteResourcesToAdd.includes(client.clientId)
); );
if (addedClients.length > 0) { if (targetToAdd) {
const targetToAdd = generateSubnetProxyTargetV2( proxyJobs.push(
siteResource, addSubnetProxyTargets(
addedClients newt.newtId,
[targetToAdd],
newt.version
)
); );
}
if (targetToAdd) { for (const client of addedClients) {
proxyJobs.push( olmJobs.push(
addSubnetProxyTargets( addPeerData(
newt.newtId, client.clientId,
[targetToAdd], siteResource.siteId,
newt.version generateRemoteSubnets([siteResource]),
) generateAliasConfig([siteResource])
); )
} );
for (const client of addedClients) {
olmJobs.push(
addPeerData(
client.clientId,
siteId,
generateRemoteSubnets([siteResource]),
generateAliasConfig([siteResource])
)
);
}
} }
} }
}
// here we use the existingSiteResource from BEFORE we updated the destination so we dont need to worry about updating destinations here // here we use the existingSiteResource from BEFORE we updated the destination so we dont need to worry about updating destinations here
// Generate targets for removed associations // Generate targets for removed associations
if (clientSiteResourcesToRemove.length > 0) { if (clientSiteResourcesToRemove.length > 0) {
const removedClients = existingClients.filter((client) => const removedClients = existingClients.filter((client) =>
clientSiteResourcesToRemove.includes(client.clientId) clientSiteResourcesToRemove.includes(client.clientId)
);
if (removedClients.length > 0) {
const targetToRemove = generateSubnetProxyTargetV2(
siteResource,
removedClients
); );
if (removedClients.length > 0) { if (targetToRemove) {
const targetToRemove = generateSubnetProxyTargetV2( proxyJobs.push(
siteResource, removeSubnetProxyTargets(
removedClients newt.newtId,
[targetToRemove],
newt.version
)
); );
}
if (targetToRemove) { for (const client of removedClients) {
proxyJobs.push( // Check if this client still has access to another resource on this site with the same destination
removeSubnetProxyTargets( const destinationStillInUse = await trx
newt.newtId, .select()
[targetToRemove], .from(siteResources)
newt.version .innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
) )
); )
} .where(
and(
for (const client of removedClients) {
// Check if this client still has access to another resource
// on this specific site with the same destination. We scope
// by siteId (via siteNetworks) rather than networkId because
// removePeerData operates per-site — a resource on a different
// site sharing the same network should not block removal here.
const destinationStillInUse = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq( eq(
clientSiteResourcesAssociationsCache.siteResourceId, clientSiteResourcesAssociationsCache.clientId,
siteResources.siteResourceId client.clientId
),
eq(siteResources.siteId, siteResource.siteId),
eq(
siteResources.destination,
siteResource.destination
),
ne(
siteResources.siteResourceId,
siteResource.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteNetworks.siteId, siteId),
eq(
siteResources.destination,
siteResource.destination
),
ne(
siteResources.siteResourceId,
siteResource.siteResourceId
)
)
);
// Only remove remote subnet if no other resource uses the same destination
const remoteSubnetsToRemove =
destinationStillInUse.length > 0
? []
: generateRemoteSubnets([siteResource]);
olmJobs.push(
removePeerData(
client.clientId,
siteId,
remoteSubnetsToRemove,
generateAliasConfig([siteResource])
)
); );
}
// Only remove remote subnet if no other resource uses the same destination
const remoteSubnetsToRemove =
destinationStillInUse.length > 0
? []
: generateRemoteSubnets([siteResource]);
olmJobs.push(
removePeerData(
client.clientId,
siteResource.siteId,
remoteSubnetsToRemove,
generateAliasConfig([siteResource])
)
);
} }
} }
} }
@@ -904,25 +863,10 @@ export async function rebuildClientAssociationsFromClient(
) )
: []; : [];
// Group by siteId for site-level associations — look up via siteNetworks since // Group by siteId for site-level associations
// siteResources no longer carries a direct siteId column. const newSiteIds = Array.from(
const networkIds = Array.from( new Set(newSiteResources.map((sr) => sr.siteId))
new Set(
newSiteResources
.map((sr) => sr.networkId)
.filter((id): id is number => id !== null)
)
); );
const newSiteIds =
networkIds.length > 0
? await trx
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, networkIds))
.then((rows) =>
Array.from(new Set(rows.map((r) => r.siteId)))
)
: [];
/////////// Process client-siteResource associations /////////// /////////// Process client-siteResource associations ///////////
@@ -1195,45 +1139,13 @@ async function handleMessagesForClientResources(
resourcesToAdd.includes(r.siteResourceId) resourcesToAdd.includes(r.siteResourceId)
); );
// Build (resource, siteId) pairs by looking up siteNetworks for each resource's networkId
const addedNetworkIds = Array.from(
new Set(
addedResources
.map((r) => r.networkId)
.filter((id): id is number => id !== null)
)
);
const addedSiteNetworkRows =
addedNetworkIds.length > 0
? await trx
.select({
networkId: siteNetworks.networkId,
siteId: siteNetworks.siteId
})
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, addedNetworkIds))
: [];
const addedNetworkToSites = new Map<number, number[]>();
for (const row of addedSiteNetworkRows) {
if (!addedNetworkToSites.has(row.networkId)) {
addedNetworkToSites.set(row.networkId, []);
}
addedNetworkToSites.get(row.networkId)!.push(row.siteId);
}
// Group by site for proxy updates // Group by site for proxy updates
const addedBySite = new Map<number, SiteResource[]>(); const addedBySite = new Map<number, SiteResource[]>();
for (const resource of addedResources) { for (const resource of addedResources) {
const siteIds = if (!addedBySite.has(resource.siteId)) {
resource.networkId != null addedBySite.set(resource.siteId, []);
? (addedNetworkToSites.get(resource.networkId) ?? [])
: [];
for (const siteId of siteIds) {
if (!addedBySite.has(siteId)) {
addedBySite.set(siteId, []);
}
addedBySite.get(siteId)!.push(resource);
} }
addedBySite.get(resource.siteId)!.push(resource);
} }
// Add subnet proxy targets for each site // Add subnet proxy targets for each site
@@ -1275,7 +1187,7 @@ async function handleMessagesForClientResources(
olmJobs.push( olmJobs.push(
addPeerData( addPeerData(
client.clientId, client.clientId,
siteId, resource.siteId,
generateRemoteSubnets([resource]), generateRemoteSubnets([resource]),
generateAliasConfig([resource]) generateAliasConfig([resource])
) )
@@ -1287,7 +1199,7 @@ async function handleMessagesForClientResources(
error.message.includes("not found") error.message.includes("not found")
) { ) {
logger.debug( logger.debug(
`Olm data not found for client ${client.clientId} and site ${siteId}, skipping addition` `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal`
); );
} else { } else {
throw error; throw error;
@@ -1304,45 +1216,13 @@ async function handleMessagesForClientResources(
.from(siteResources) .from(siteResources)
.where(inArray(siteResources.siteResourceId, resourcesToRemove)); .where(inArray(siteResources.siteResourceId, resourcesToRemove));
// Build (resource, siteId) pairs via siteNetworks
const removedNetworkIds = Array.from(
new Set(
removedResources
.map((r) => r.networkId)
.filter((id): id is number => id !== null)
)
);
const removedSiteNetworkRows =
removedNetworkIds.length > 0
? await trx
.select({
networkId: siteNetworks.networkId,
siteId: siteNetworks.siteId
})
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, removedNetworkIds))
: [];
const removedNetworkToSites = new Map<number, number[]>();
for (const row of removedSiteNetworkRows) {
if (!removedNetworkToSites.has(row.networkId)) {
removedNetworkToSites.set(row.networkId, []);
}
removedNetworkToSites.get(row.networkId)!.push(row.siteId);
}
// Group by site for proxy updates // Group by site for proxy updates
const removedBySite = new Map<number, SiteResource[]>(); const removedBySite = new Map<number, SiteResource[]>();
for (const resource of removedResources) { for (const resource of removedResources) {
const siteIds = if (!removedBySite.has(resource.siteId)) {
resource.networkId != null removedBySite.set(resource.siteId, []);
? (removedNetworkToSites.get(resource.networkId) ?? [])
: [];
for (const siteId of siteIds) {
if (!removedBySite.has(siteId)) {
removedBySite.set(siteId, []);
}
removedBySite.get(siteId)!.push(resource);
} }
removedBySite.get(resource.siteId)!.push(resource);
} }
// Remove subnet proxy targets for each site // Remove subnet proxy targets for each site
@@ -1380,11 +1260,7 @@ async function handleMessagesForClientResources(
} }
try { try {
// Check if this client still has access to another resource // Check if this client still has access to another resource on this site with the same destination
// on this specific site with the same destination. We scope
// by siteId (via siteNetworks) rather than networkId because
// removePeerData operates per-site — a resource on a different
// site sharing the same network should not block removal here.
const destinationStillInUse = await trx const destinationStillInUse = await trx
.select() .select()
.from(siteResources) .from(siteResources)
@@ -1395,17 +1271,13 @@ async function handleMessagesForClientResources(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where( .where(
and( and(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
), ),
eq(siteNetworks.siteId, siteId), eq(siteResources.siteId, resource.siteId),
eq( eq(
siteResources.destination, siteResources.destination,
resource.destination resource.destination
@@ -1427,7 +1299,7 @@ async function handleMessagesForClientResources(
olmJobs.push( olmJobs.push(
removePeerData( removePeerData(
client.clientId, client.clientId,
siteId, resource.siteId,
remoteSubnetsToRemove, remoteSubnetsToRemove,
generateAliasConfig([resource]) generateAliasConfig([resource])
) )
@@ -1439,7 +1311,7 @@ async function handleMessagesForClientResources(
error.message.includes("not found") error.message.includes("not found")
) { ) {
logger.debug( logger.debug(
`Olm data not found for client ${client.clientId} and site ${siteId}, skipping removal` `Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal`
); );
} else { } else {
throw error; throw error;

View File

@@ -3,7 +3,7 @@ import config from "./config";
import { getHostMeta } from "./hostMeta"; import { getHostMeta } from "./hostMeta";
import logger from "@server/logger"; import logger from "@server/logger";
import { apiKeys, db, roles, siteResources } from "@server/db"; import { apiKeys, db, roles, siteResources } from "@server/db";
import { sites, users, orgs, resources, clients, idp } from "@server/db"; import { sites, users, orgs, resources, clients, idp, siteBandwidth } from "@server/db";
import { eq, count, notInArray, and, isNotNull, isNull } from "drizzle-orm"; import { eq, count, notInArray, and, isNotNull, isNull } from "drizzle-orm";
import { APP_VERSION } from "./consts"; import { APP_VERSION } from "./consts";
import crypto from "crypto"; import crypto from "crypto";
@@ -150,12 +150,13 @@ class TelemetryClient {
const siteDetails = await db const siteDetails = await db
.select({ .select({
siteName: sites.name, siteName: sites.name,
megabytesIn: sites.megabytesIn, megabytesIn: siteBandwidth.megabytesIn,
megabytesOut: sites.megabytesOut, megabytesOut: siteBandwidth.megabytesOut,
type: sites.type, type: sites.type,
online: sites.online online: sites.online
}) })
.from(sites); .from(sites)
.leftJoin(siteBandwidth, eq(siteBandwidth.siteId, sites.siteId));
const supporterKey = config.getSupporterData(); const supporterKey = config.getSupporterData();

View File

@@ -18,10 +18,11 @@ import {
subscriptionItems, subscriptionItems,
usage, usage,
sites, sites,
siteBandwidth,
customers, customers,
orgs orgs
} from "@server/db"; } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features"; import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features";
import stripe from "#private/lib/stripe"; import stripe from "#private/lib/stripe";
@@ -253,14 +254,19 @@ export async function handleSubscriptionUpdated(
); );
} }
// Also reset the sites to 0 // Also reset the site bandwidth to 0
await trx await trx
.update(sites) .update(siteBandwidth)
.set({ .set({
megabytesIn: 0, megabytesIn: 0,
megabytesOut: 0 megabytesOut: 0
}) })
.where(eq(sites.orgId, orgId)); .where(
inArray(
siteBandwidth.siteId,
trx.select({ siteId: sites.siteId }).from(sites).where(eq(sites.orgId, orgId))
)
);
}); });
} }
} }

View File

@@ -22,15 +22,11 @@ import { OpenAPITags, registry } from "@server/openApi";
import { db, domainNamespaces, resources } from "@server/db"; import { db, domainNamespaces, resources } from "@server/db";
import { inArray } from "drizzle-orm"; import { inArray } from "drizzle-orm";
import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types"; import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types";
import { build } from "@server/build";
import { isSubscribed } from "#private/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const paramsSchema = z.strictObject({}); const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({ const querySchema = z.strictObject({
subdomain: z.string(), subdomain: z.string()
// orgId: build === "saas" ? z.string() : z.string().optional() // Required for saas, optional otherwise
}); });
registry.registerPath({ registry.registerPath({
@@ -62,23 +58,6 @@ export async function checkDomainNamespaceAvailability(
} }
const { subdomain } = parsedQuery.data; const { subdomain } = parsedQuery.data;
// if (
// build == "saas" &&
// !isSubscribed(orgId!, tierMatrix.domainNamespaces)
// ) {
// // return not available
// return response<CheckDomainAvailabilityResponse>(res, {
// data: {
// available: false,
// options: []
// },
// success: true,
// error: false,
// message: "Your current subscription does not support custom domain namespaces. Please upgrade to access this feature.",
// status: HttpCode.OK
// });
// }
const namespaces = await db.select().from(domainNamespaces); const namespaces = await db.select().from(domainNamespaces);
let possibleDomains = namespaces.map((ns) => { let possibleDomains = namespaces.map((ns) => {
const desired = `${subdomain}.${ns.domainNamespaceId}`; const desired = `${subdomain}.${ns.domainNamespaceId}`;

View File

@@ -22,9 +22,6 @@ import { eq, sql } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { isSubscribed } from "#private/lib/isSubscribed";
import { build } from "@server/build";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const paramsSchema = z.strictObject({}); const paramsSchema = z.strictObject({});
@@ -40,8 +37,7 @@ const querySchema = z.strictObject({
.optional() .optional()
.default("0") .default("0")
.transform(Number) .transform(Number)
.pipe(z.int().nonnegative()), .pipe(z.int().nonnegative())
// orgId: build === "saas" ? z.string() : z.string().optional() // Required for saas, optional otherwise
}); });
async function query(limit: number, offset: number) { async function query(limit: number, offset: number) {
@@ -103,26 +99,6 @@ export async function listDomainNamespaces(
); );
} }
// if (
// build == "saas" &&
// !isSubscribed(orgId!, tierMatrix.domainNamespaces)
// ) {
// return response<ListDomainNamespacesResponse>(res, {
// data: {
// domainNamespaces: [],
// pagination: {
// total: 0,
// limit,
// offset
// }
// },
// success: true,
// error: false,
// message: "No namespaces found. Your current subscription does not support custom domain namespaces. Please upgrade to access this feature.",
// status: HttpCode.OK
// });
// }
const domainNamespacesList = await query(limit, offset); const domainNamespacesList = await query(limit, offset);
const [{ count }] = await db const [{ count }] = await db

View File

@@ -21,7 +21,7 @@ import {
roles, roles,
roundTripMessageTracker, roundTripMessageTracker,
siteResources, siteResources,
siteNetworks, sites,
userOrgs userOrgs
} from "@server/db"; } from "@server/db";
import { logAccessAudit } from "#private/lib/logAccessAudit"; import { logAccessAudit } from "#private/lib/logAccessAudit";
@@ -63,12 +63,10 @@ const bodySchema = z
export type SignSshKeyResponse = { export type SignSshKeyResponse = {
certificate: string; certificate: string;
messageIds: number[];
messageId: number; messageId: number;
sshUsername: string; sshUsername: string;
sshHost: string; sshHost: string;
resourceId: number; resourceId: number;
siteIds: number[];
siteId: number; siteId: number;
keyId: string; keyId: string;
validPrincipals: string[]; validPrincipals: string[];
@@ -262,7 +260,10 @@ export async function signSshKey(
.update(userOrgs) .update(userOrgs)
.set({ pamUsername: usernameToUse }) .set({ pamUsername: usernameToUse })
.where( .where(
and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)) and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, userId)
)
); );
} else { } else {
usernameToUse = userOrg.pamUsername; usernameToUse = userOrg.pamUsername;
@@ -394,12 +395,21 @@ export async function signSshKey(
homedir = roleRows[0].sshCreateHomeDir ?? null; homedir = roleRows[0].sshCreateHomeDir ?? null;
} }
const sites = await db // get the site
.select({ siteId: siteNetworks.siteId }) const [newt] = await db
.from(siteNetworks) .select()
.where(eq(siteNetworks.networkId, resource.networkId!)); .from(newts)
.where(eq(newts.siteId, resource.siteId))
.limit(1);
const siteIds = sites.map((site) => site.siteId); if (!newt) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found"
)
);
}
// Sign the public key // Sign the public key
const now = BigInt(Math.floor(Date.now() / 1000)); const now = BigInt(Math.floor(Date.now() / 1000));
@@ -413,65 +423,44 @@ export async function signSshKey(
validBefore: now + validFor validBefore: now + validFor
}); });
const messageIds: number[] = []; const [message] = await db
for (const siteId of siteIds) { .insert(roundTripMessageTracker)
// get the site .values({
const [newt] = await db wsClientId: newt.newtId,
.select() messageType: `newt/pam/connection`,
.from(newts) sentAt: Math.floor(Date.now() / 1000)
.where(eq(newts.siteId, siteId)) })
.limit(1); .returning();
if (!newt) { if (!message) {
return next( return next(
createHttpError( createHttpError(
HttpCode.INTERNAL_SERVER_ERROR, HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found" "Failed to create message tracker entry"
) )
); );
}
const [message] = await db
.insert(roundTripMessageTracker)
.values({
wsClientId: newt.newtId,
messageType: `newt/pam/connection`,
sentAt: Math.floor(Date.now() / 1000)
})
.returning();
if (!message) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create message tracker entry"
)
);
}
messageIds.push(message.messageId);
await sendToClient(newt.newtId, {
type: `newt/pam/connection`,
data: {
messageId: message.messageId,
orgId: orgId,
agentPort: resource.authDaemonPort ?? 22123,
externalAuthDaemon: resource.authDaemonMode === "remote",
agentHost: resource.destination,
caCert: caKeys.publicKeyOpenSSH,
username: usernameToUse,
niceId: resource.niceId,
metadata: {
sudoMode: sudoMode,
sudoCommands: parsedSudoCommands,
homedir: homedir,
groups: parsedGroups
}
}
});
} }
await sendToClient(newt.newtId, {
type: `newt/pam/connection`,
data: {
messageId: message.messageId,
orgId: orgId,
agentPort: resource.authDaemonPort ?? 22123,
externalAuthDaemon: resource.authDaemonMode === "remote",
agentHost: resource.destination,
caCert: caKeys.publicKeyOpenSSH,
username: usernameToUse,
niceId: resource.niceId,
metadata: {
sudoMode: sudoMode,
sudoCommands: parsedSudoCommands,
homedir: homedir,
groups: parsedGroups
}
}
});
const expiresIn = Number(validFor); // seconds const expiresIn = Number(validFor); // seconds
let sshHost; let sshHost;
@@ -491,7 +480,7 @@ export async function signSshKey(
metadata: JSON.stringify({ metadata: JSON.stringify({
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
resource: resource.name, resource: resource.name,
siteIds: siteIds siteId: resource.siteId,
}) })
}); });
@@ -516,13 +505,11 @@ export async function signSshKey(
return response<SignSshKeyResponse>(res, { return response<SignSshKeyResponse>(res, {
data: { data: {
certificate: cert.certificate, certificate: cert.certificate,
messageIds: messageIds, messageId: message.messageId,
messageId: messageIds[0], // just pick the first one for backward compatibility
sshUsername: usernameToUse, sshUsername: usernameToUse,
sshHost: sshHost, sshHost: sshHost,
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
siteIds: siteIds, siteId: resource.siteId,
siteId: siteIds[0], // just pick the first one for backward compatibility
keyId: cert.keyId, keyId: cert.keyId,
validPrincipals: cert.validPrincipals, validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(), validAfter: cert.validAfter.toISOString(),

View File

@@ -1,4 +1,5 @@
import { import {
clientBandwidth,
clients, clients,
clientSitesAssociationsCache, clientSitesAssociationsCache,
currentFingerprint, currentFingerprint,
@@ -180,8 +181,8 @@ function queryClientsBase() {
name: clients.name, name: clients.name,
pubKey: clients.pubKey, pubKey: clients.pubKey,
subnet: clients.subnet, subnet: clients.subnet,
megabytesIn: clients.megabytesIn, megabytesIn: clientBandwidth.megabytesIn,
megabytesOut: clients.megabytesOut, megabytesOut: clientBandwidth.megabytesOut,
orgName: orgs.name, orgName: orgs.name,
type: clients.type, type: clients.type,
online: clients.online, online: clients.online,
@@ -200,7 +201,8 @@ function queryClientsBase() {
.leftJoin(orgs, eq(clients.orgId, orgs.orgId)) .leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.leftJoin(olms, eq(clients.clientId, olms.clientId)) .leftJoin(olms, eq(clients.clientId, olms.clientId))
.leftJoin(users, eq(clients.userId, users.userId)) .leftJoin(users, eq(clients.userId, users.userId))
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)); .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId))
.leftJoin(clientBandwidth, eq(clientBandwidth.clientId, clients.clientId));
} }
async function getSiteAssociations(clientIds: number[]) { async function getSiteAssociations(clientIds: number[]) {
@@ -367,9 +369,15 @@ export async function listClients(
.offset(pageSize * (page - 1)) .offset(pageSize * (page - 1))
.orderBy( .orderBy(
sort_by sort_by
? order === "asc" ? (() => {
? asc(clients[sort_by]) const field =
: desc(clients[sort_by]) sort_by === "megabytesIn"
? clientBandwidth.megabytesIn
: sort_by === "megabytesOut"
? clientBandwidth.megabytesOut
: clients.name;
return order === "asc" ? asc(field) : desc(field);
})()
: asc(clients.name) : asc(clients.name)
); );

View File

@@ -1,5 +1,6 @@
import { build } from "@server/build"; import { build } from "@server/build";
import { import {
clientBandwidth,
clients, clients,
currentFingerprint, currentFingerprint,
db, db,
@@ -211,8 +212,8 @@ function queryUserDevicesBase() {
name: clients.name, name: clients.name,
pubKey: clients.pubKey, pubKey: clients.pubKey,
subnet: clients.subnet, subnet: clients.subnet,
megabytesIn: clients.megabytesIn, megabytesIn: clientBandwidth.megabytesIn,
megabytesOut: clients.megabytesOut, megabytesOut: clientBandwidth.megabytesOut,
orgName: orgs.name, orgName: orgs.name,
type: clients.type, type: clients.type,
online: clients.online, online: clients.online,
@@ -239,7 +240,8 @@ function queryUserDevicesBase() {
.leftJoin(orgs, eq(clients.orgId, orgs.orgId)) .leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.leftJoin(olms, eq(clients.clientId, olms.clientId)) .leftJoin(olms, eq(clients.clientId, olms.clientId))
.leftJoin(users, eq(clients.userId, users.userId)) .leftJoin(users, eq(clients.userId, users.userId))
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)); .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId))
.leftJoin(clientBandwidth, eq(clientBandwidth.clientId, clients.clientId));
} }
type OlmWithUpdateAvailable = Awaited< type OlmWithUpdateAvailable = Awaited<
@@ -427,9 +429,15 @@ export async function listUserDevices(
.offset(pageSize * (page - 1)) .offset(pageSize * (page - 1))
.orderBy( .orderBy(
sort_by sort_by
? order === "asc" ? (() => {
? asc(clients[sort_by]) const field =
: desc(clients[sort_by]) sort_by === "megabytesIn"
? clientBandwidth.megabytesIn
: sort_by === "megabytesOut"
? clientBandwidth.megabytesOut
: clients.name;
return order === "asc" ? asc(field) : desc(field);
})()
: asc(clients.clientId) : asc(clients.clientId)
); );

View File

@@ -122,7 +122,7 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
const snapshot = accumulator; const snapshot = accumulator;
accumulator = new Map<string, AccumulatorEntry>(); accumulator = new Map<string, AccumulatorEntry>();
const currentTime = new Date().toISOString(); const currentEpoch = Math.floor(Date.now() / 1000);
// Sort by publicKey for consistent lock ordering across concurrent // Sort by publicKey for consistent lock ordering across concurrent
// writers — deadlock-prevention strategy. // writers — deadlock-prevention strategy.
@@ -157,33 +157,52 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
orgId: string; orgId: string;
pubKey: string; pubKey: string;
}>(sql` }>(sql`
UPDATE sites WITH upsert AS (
SET INSERT INTO "siteBandwidth" ("siteId", "bytesIn", "bytesOut", "lastBandwidthUpdate")
"bytesOut" = COALESCE("bytesOut", 0) + ${bytesIn}, SELECT s."siteId", ${bytesIn}, ${bytesOut}, ${currentEpoch}
"bytesIn" = COALESCE("bytesIn", 0) + ${bytesOut}, FROM "sites" s WHERE s."pubKey" = ${publicKey}
"lastBandwidthUpdate" = ${currentTime} ON CONFLICT ("siteId") DO UPDATE SET
WHERE "pubKey" = ${publicKey} "bytesIn" = COALESCE("siteBandwidth"."bytesIn", 0) + EXCLUDED."bytesIn",
RETURNING "orgId", "pubKey" "bytesOut" = COALESCE("siteBandwidth"."bytesOut", 0) + EXCLUDED."bytesOut",
"lastBandwidthUpdate" = EXCLUDED."lastBandwidthUpdate"
RETURNING "siteId"
)
SELECT u."siteId", s."orgId", s."pubKey"
FROM upsert u
INNER JOIN "sites" s ON s."siteId" = u."siteId"
`); `);
results.push(...result); results.push(...result);
} }
return results; return results;
} }
// PostgreSQL: batch UPDATE … FROM (VALUES …) — single round-trip per chunk. // PostgreSQL: batch UPSERT via CTE — single round-trip per chunk.
const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) => const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) =>
sql`(${publicKey}::text, ${bytesIn}::real, ${bytesOut}::real)` sql`(${publicKey}::text, ${bytesIn}::real, ${bytesOut}::real)`
); );
const valuesClause = sql.join(valuesList, sql`, `); const valuesClause = sql.join(valuesList, sql`, `);
return dbQueryRows<{ orgId: string; pubKey: string }>(sql` return dbQueryRows<{ orgId: string; pubKey: string }>(sql`
UPDATE sites WITH vals(pub_key, bytes_in, bytes_out) AS (
SET VALUES ${valuesClause}
"bytesOut" = COALESCE("bytesOut", 0) + v.bytes_in, ),
"bytesIn" = COALESCE("bytesIn", 0) + v.bytes_out, site_lookup AS (
"lastBandwidthUpdate" = ${currentTime} SELECT s."siteId", s."orgId", s."pubKey", v.bytes_in, v.bytes_out
FROM (VALUES ${valuesClause}) AS v(pub_key, bytes_in, bytes_out) FROM vals v
WHERE sites."pubKey" = v.pub_key INNER JOIN "sites" s ON s."pubKey" = v.pub_key
RETURNING sites."orgId" AS "orgId", sites."pubKey" AS "pubKey" ),
upsert AS (
INSERT INTO "siteBandwidth" ("siteId", "bytesIn", "bytesOut", "lastBandwidthUpdate")
SELECT sl."siteId", sl.bytes_in, sl.bytes_out, ${currentEpoch}::integer
FROM site_lookup sl
ON CONFLICT ("siteId") DO UPDATE SET
"bytesIn" = COALESCE("siteBandwidth"."bytesIn", 0) + EXCLUDED."bytesIn",
"bytesOut" = COALESCE("siteBandwidth"."bytesOut", 0) + EXCLUDED."bytesOut",
"lastBandwidthUpdate" = EXCLUDED."lastBandwidthUpdate"
RETURNING "siteId"
)
SELECT u."siteId", s."orgId", s."pubKey"
FROM upsert u
INNER JOIN "sites" s ON s."siteId" = u."siteId"
`); `);
}, `flush bandwidth chunk [${i}${chunkEnd}]`); }, `flush bandwidth chunk [${i}${chunkEnd}]`);
} catch (error) { } catch (error) {

View File

@@ -4,10 +4,8 @@ import {
clientSitesAssociationsCache, clientSitesAssociationsCache,
db, db,
ExitNode, ExitNode,
networks,
resources, resources,
Site, Site,
siteNetworks,
siteResources, siteResources,
targetHealthCheck, targetHealthCheck,
targets targets
@@ -139,14 +137,11 @@ export async function buildClientConfigurationForNewtClient(
// Filter out any null values from peers that didn't have an olm // Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null); const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled site resources for this site by joining through siteNetworks and networks // Get all enabled site resources for this site
const allSiteResources = await db const allSiteResources = await db
.select() .select()
.from(siteResources) .from(siteResources)
.innerJoin(networks, eq(siteResources.networkId, networks.networkId)) .where(eq(siteResources.siteId, siteId));
.innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
.where(eq(siteNetworks.siteId, siteId))
.then((rows) => rows.map((r) => r.siteResources));
const targetsToSend: SubnetProxyTargetV2[] = []; const targetsToSend: SubnetProxyTargetV2[] = [];

View File

@@ -1,11 +1,11 @@
import { db, newts, sites, targetHealthCheck, targets } from "@server/db"; import { db, newts, sites, targetHealthCheck, targets, sitePing, siteBandwidth } from "@server/db";
import { import {
hasActiveConnections, hasActiveConnections,
getClientConfigVersion getClientConfigVersion
} from "#dynamic/routers/ws"; } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { Newt } from "@server/db"; import { Newt } from "@server/db";
import { eq, lt, isNull, and, or, ne, not } from "drizzle-orm"; import { eq, lt, isNull, and, or, ne } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { sendNewtSyncMessage } from "./sync"; import { sendNewtSyncMessage } from "./sync";
import { recordPing } from "./pingAccumulator"; import { recordPing } from "./pingAccumulator";
@@ -41,17 +41,18 @@ export const startNewtOfflineChecker = (): void => {
.select({ .select({
siteId: sites.siteId, siteId: sites.siteId,
newtId: newts.newtId, newtId: newts.newtId,
lastPing: sites.lastPing lastPing: sitePing.lastPing
}) })
.from(sites) .from(sites)
.innerJoin(newts, eq(newts.siteId, sites.siteId)) .innerJoin(newts, eq(newts.siteId, sites.siteId))
.leftJoin(sitePing, eq(sitePing.siteId, sites.siteId))
.where( .where(
and( and(
eq(sites.online, true), eq(sites.online, true),
eq(sites.type, "newt"), eq(sites.type, "newt"),
or( or(
lt(sites.lastPing, twoMinutesAgo), lt(sitePing.lastPing, twoMinutesAgo),
isNull(sites.lastPing) isNull(sitePing.lastPing)
) )
) )
); );
@@ -112,15 +113,11 @@ export const startNewtOfflineChecker = (): void => {
.select({ .select({
siteId: sites.siteId, siteId: sites.siteId,
online: sites.online, online: sites.online,
lastBandwidthUpdate: sites.lastBandwidthUpdate lastBandwidthUpdate: siteBandwidth.lastBandwidthUpdate
}) })
.from(sites) .from(sites)
.where( .innerJoin(siteBandwidth, eq(siteBandwidth.siteId, sites.siteId))
and( .where(eq(sites.type, "wireguard"));
eq(sites.type, "wireguard"),
not(isNull(sites.lastBandwidthUpdate))
)
);
const wireguardOfflineThreshold = Math.floor( const wireguardOfflineThreshold = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_BANDWIDTH_MS) / 1000 (Date.now() - OFFLINE_THRESHOLD_BANDWIDTH_MS) / 1000
@@ -128,12 +125,7 @@ export const startNewtOfflineChecker = (): void => {
// loop over each one. If its offline and there is a new update then mark it online. If its online and there is no update then mark it offline // loop over each one. If its offline and there is a new update then mark it online. If its online and there is no update then mark it offline
for (const site of allWireguardSites) { for (const site of allWireguardSites) {
const lastBandwidthUpdate = if ((site.lastBandwidthUpdate ?? 0) < wireguardOfflineThreshold && site.online) {
new Date(site.lastBandwidthUpdate!).getTime() / 1000;
if (
lastBandwidthUpdate < wireguardOfflineThreshold &&
site.online
) {
logger.info( logger.info(
`Marking wireguard site ${site.siteId} offline: no bandwidth update in over ${OFFLINE_THRESHOLD_BANDWIDTH_MS / 60000} minutes` `Marking wireguard site ${site.siteId} offline: no bandwidth update in over ${OFFLINE_THRESHOLD_BANDWIDTH_MS / 60000} minutes`
); );
@@ -142,10 +134,7 @@ export const startNewtOfflineChecker = (): void => {
.update(sites) .update(sites)
.set({ online: false }) .set({ online: false })
.where(eq(sites.siteId, site.siteId)); .where(eq(sites.siteId, site.siteId));
} else if ( } else if ((site.lastBandwidthUpdate ?? 0) >= wireguardOfflineThreshold && !site.online) {
lastBandwidthUpdate >= wireguardOfflineThreshold &&
!site.online
) {
logger.info( logger.info(
`Marking wireguard site ${site.siteId} online: recent bandwidth update` `Marking wireguard site ${site.siteId} online: recent bandwidth update`
); );

View File

@@ -1,6 +1,5 @@
import { db } from "@server/db"; import { db, clients, clientBandwidth } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients } from "@server/db";
import { eq, sql } from "drizzle-orm"; import { eq, sql } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
@@ -85,7 +84,7 @@ export async function flushBandwidthToDb(): Promise<void> {
const snapshot = accumulator; const snapshot = accumulator;
accumulator = new Map<string, BandwidthAccumulator>(); accumulator = new Map<string, BandwidthAccumulator>();
const currentTime = new Date().toISOString(); const currentEpoch = Math.floor(Date.now() / 1000);
// Sort by publicKey for consistent lock ordering across concurrent // Sort by publicKey for consistent lock ordering across concurrent
// writers — this is the same deadlock-prevention strategy used in the // writers — this is the same deadlock-prevention strategy used in the
@@ -101,19 +100,37 @@ export async function flushBandwidthToDb(): Promise<void> {
for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) { for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) {
try { try {
await withDeadlockRetry(async () => { await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid the SELECT-then-UPDATE // Find clientId by pubKey
// anti-pattern and the races it would introduce. const [clientRow] = await db
.select({ clientId: clients.clientId })
.from(clients)
.where(eq(clients.pubKey, publicKey))
.limit(1);
if (!clientRow) {
logger.warn(`No client found for pubKey ${publicKey}, skipping`);
return;
}
await db await db
.update(clients) .insert(clientBandwidth)
.set({ .values({
clientId: clientRow.clientId,
// Note: bytesIn from peer goes to megabytesOut (data // Note: bytesIn from peer goes to megabytesOut (data
// sent to client) and bytesOut from peer goes to // sent to client) and bytesOut from peer goes to
// megabytesIn (data received from client). // megabytesIn (data received from client).
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`, megabytesOut: bytesIn,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`, megabytesIn: bytesOut,
lastBandwidthUpdate: currentTime lastBandwidthUpdate: currentEpoch
}) })
.where(eq(clients.pubKey, publicKey)); .onConflictDoUpdate({
target: clientBandwidth.clientId,
set: {
megabytesOut: sql`COALESCE(${clientBandwidth.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clientBandwidth.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentEpoch
}
});
}, `flush bandwidth for client ${publicKey}`); }, `flush bandwidth for client ${publicKey}`);
} catch (error) { } catch (error) {
logger.error( logger.error(

View File

@@ -1,6 +1,6 @@
import { db } from "@server/db"; import { db } from "@server/db";
import { sites, clients, olms } from "@server/db"; import { sites, clients, olms, sitePing, clientPing } from "@server/db";
import { inArray } from "drizzle-orm"; import { inArray, sql } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
/** /**
@@ -81,11 +81,8 @@ export function recordClientPing(
/** /**
* Flush all accumulated site pings to the database. * Flush all accumulated site pings to the database.
* *
* Each batch of up to BATCH_SIZE rows is written with a **single** UPDATE * For each batch: first upserts individual per-site timestamps into
* statement. We use the maximum timestamp across the batch so that `lastPing` * `sitePing`, then bulk-updates `sites.online = true`.
* reflects the most recent ping seen for any site in the group. This avoids
* the multi-statement transaction that previously created additional
* row-lock ordering hazards.
*/ */
async function flushSitePingsToDb(): Promise<void> { async function flushSitePingsToDb(): Promise<void> {
if (pendingSitePings.size === 0) { if (pendingSitePings.size === 0) {
@@ -103,20 +100,25 @@ async function flushSitePingsToDb(): Promise<void> {
for (let i = 0; i < entries.length; i += BATCH_SIZE) { for (let i = 0; i < entries.length; i += BATCH_SIZE) {
const batch = entries.slice(i, i + BATCH_SIZE); const batch = entries.slice(i, i + BATCH_SIZE);
// Use the latest timestamp in the batch so that `lastPing` always
// moves forward. Using a single timestamp for the whole batch means
// we only ever need one UPDATE statement (no transaction).
const maxTimestamp = Math.max(...batch.map(([, ts]) => ts));
const siteIds = batch.map(([id]) => id); const siteIds = batch.map(([id]) => id);
try { try {
await withRetry(async () => { await withRetry(async () => {
const rows = batch.map(([siteId, ts]) => ({ siteId, lastPing: ts }));
// Step 1: Upsert ping timestamps into sitePing
await db
.insert(sitePing)
.values(rows)
.onConflictDoUpdate({
target: sitePing.siteId,
set: { lastPing: sql`excluded."lastPing"` }
});
// Step 2: Update online status on sites
await db await db
.update(sites) .update(sites)
.set({ .set({ online: true })
online: true,
lastPing: maxTimestamp
})
.where(inArray(sites.siteId, siteIds)); .where(inArray(sites.siteId, siteIds));
}, "flushSitePingsToDb"); }, "flushSitePingsToDb");
} catch (error) { } catch (error) {
@@ -139,7 +141,8 @@ async function flushSitePingsToDb(): Promise<void> {
/** /**
* Flush all accumulated client (OLM) pings to the database. * Flush all accumulated client (OLM) pings to the database.
* *
* Same single-UPDATE-per-batch approach as `flushSitePingsToDb`. * For each batch: first upserts individual per-client timestamps into
* `clientPing`, then bulk-updates `clients.online = true, archived = false`.
*/ */
async function flushClientPingsToDb(): Promise<void> { async function flushClientPingsToDb(): Promise<void> {
if (pendingClientPings.size === 0 && pendingOlmArchiveResets.size === 0) { if (pendingClientPings.size === 0 && pendingOlmArchiveResets.size === 0) {
@@ -161,18 +164,25 @@ async function flushClientPingsToDb(): Promise<void> {
for (let i = 0; i < entries.length; i += BATCH_SIZE) { for (let i = 0; i < entries.length; i += BATCH_SIZE) {
const batch = entries.slice(i, i + BATCH_SIZE); const batch = entries.slice(i, i + BATCH_SIZE);
const maxTimestamp = Math.max(...batch.map(([, ts]) => ts));
const clientIds = batch.map(([id]) => id); const clientIds = batch.map(([id]) => id);
try { try {
await withRetry(async () => { await withRetry(async () => {
const rows = batch.map(([clientId, ts]) => ({ clientId, lastPing: ts }));
// Step 1: Upsert ping timestamps into clientPing
await db
.insert(clientPing)
.values(rows)
.onConflictDoUpdate({
target: clientPing.clientId,
set: { lastPing: sql`excluded."lastPing"` }
});
// Step 2: Update online + unarchive on clients
await db await db
.update(clients) .update(clients)
.set({ .set({ online: true, archived: false })
lastPing: maxTimestamp,
online: true,
archived: false
})
.where(inArray(clients.clientId, clientIds)); .where(inArray(clients.clientId, clientIds));
}, "flushClientPingsToDb"); }, "flushClientPingsToDb");
} catch (error) { } catch (error) {

View File

@@ -4,8 +4,6 @@ import {
clientSitesAssociationsCache, clientSitesAssociationsCache,
db, db,
exitNodes, exitNodes,
networks,
siteNetworks,
siteResources, siteResources,
sites sites
} from "@server/db"; } from "@server/db";
@@ -61,17 +59,9 @@ export async function buildSiteConfigurationForOlmClient(
clientSiteResourcesAssociationsCache.siteResourceId clientSiteResourcesAssociationsCache.siteResourceId
) )
) )
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
eq(networks.networkId, siteNetworks.networkId)
)
.where( .where(
and( and(
eq(siteNetworks.siteId, site.siteId), eq(siteResources.siteId, site.siteId),
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
@@ -79,7 +69,6 @@ export async function buildSiteConfigurationForOlmClient(
) )
); );
if (jitMode) { if (jitMode) {
// Add site configuration to the array // Add site configuration to the array
siteConfigurations.push({ siteConfigurations.push({

View File

@@ -1,8 +1,8 @@
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws"; import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { db } from "@server/db"; import { db } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, olms, Olm } from "@server/db"; import { clients, olms, Olm, clientPing } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm"; import { eq, lt, isNull, and, or, inArray } from "drizzle-orm";
import { recordClientPing } from "@server/routers/newt/pingAccumulator"; import { recordClientPing } from "@server/routers/newt/pingAccumulator";
import logger from "@server/logger"; import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -37,21 +37,33 @@ export const startOlmOfflineChecker = (): void => {
// TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// Find clients that haven't pinged in the last 2 minutes and mark them as offline // Find clients that haven't pinged in the last 2 minutes and mark them as offline
const offlineClients = await db const staleClientRows = await db
.update(clients) .select({
.set({ online: false }) clientId: clients.clientId,
olmId: clients.olmId,
lastPing: clientPing.lastPing
})
.from(clients)
.leftJoin(clientPing, eq(clientPing.clientId, clients.clientId))
.where( .where(
and( and(
eq(clients.online, true), eq(clients.online, true),
or( or(
lt(clients.lastPing, twoMinutesAgo), lt(clientPing.lastPing, twoMinutesAgo),
isNull(clients.lastPing) isNull(clientPing.lastPing)
) )
) )
) );
.returning();
for (const offlineClient of offlineClients) { if (staleClientRows.length > 0) {
const staleClientIds = staleClientRows.map((c) => c.clientId);
await db
.update(clients)
.set({ online: false })
.where(inArray(clients.clientId, staleClientIds));
}
for (const offlineClient of staleClientRows) {
logger.info( logger.info(
`Kicking offline olm client ${offlineClient.clientId} due to inactivity` `Kicking offline olm client ${offlineClient.clientId} due to inactivity`
); );

View File

@@ -4,12 +4,10 @@ import {
db, db,
exitNodes, exitNodes,
Site, Site,
siteNetworks, siteResources
siteResources,
sites
} from "@server/db"; } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, Olm } from "@server/db"; import { clients, Olm, sites } from "@server/db";
import { and, eq, or } from "drizzle-orm"; import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers"; import { initPeerAddHandshake } from "./peers";
@@ -46,31 +44,20 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
const { siteId, resourceId, chainId } = message.data; const { siteId, resourceId, chainId } = message.data;
const sendCancel = async () => { let site: Site | null = null;
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: { chainId }
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
};
let sitesToProcess: Site[] = [];
if (siteId) { if (siteId) {
// get the site
const [siteRes] = await db const [siteRes] = await db
.select() .select()
.from(sites) .from(sites)
.where(eq(sites.siteId, siteId)) .where(eq(sites.siteId, siteId))
.limit(1); .limit(1);
if (siteRes) { if (siteRes) {
sitesToProcess = [siteRes]; site = siteRes;
} }
} else if (resourceId) { }
if (resourceId && !site) {
const resources = await db const resources = await db
.select() .select()
.from(siteResources) .from(siteResources)
@@ -85,17 +72,27 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
); );
if (!resources || resources.length === 0) { if (!resources || resources.length === 0) {
logger.error( logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
`handleOlmServerInitAddPeerHandshake: Resource not found` // cancel the request from the olm side to not keep doing this
); await sendToClient(
await sendCancel(); olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return; return;
} }
if (resources.length > 1) { if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches // error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
logger.error( logger.error(
`handleOlmServerInitAddPeerHandshake: Multiple resources found matching the criteria` `handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
); );
return; return;
} }
@@ -120,120 +117,125 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
if (currentResourceAssociationCaches.length === 0) { if (currentResourceAssociationCaches.length === 0) {
logger.error( logger.error(
`handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}` `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
); );
await sendCancel(); // cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return; return;
} }
if (!resource.networkId) { const siteIdFromResource = resource.siteId;
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteIdFromResource));
if (!siteRes) {
logger.error( logger.error(
`handleOlmServerInitAddPeerHandshake: Resource ${resource.siteResourceId} has no network` `handleOlmServerPeerAddMessage: Site with ID ${site} not found`
); );
await sendCancel();
return; return;
} }
// Get all sites associated with this resource's network via siteNetworks site = siteRes;
const siteRows = await db
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, resource.networkId));
if (!siteRows || siteRows.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites found for resource ${resource.siteResourceId}`
);
await sendCancel();
return;
}
// Fetch full site objects for all network members
const foundSites = await Promise.all(
siteRows.map(async ({ siteId: sid }) => {
const [s] = await db
.select()
.from(sites)
.where(eq(sites.siteId, sid))
.limit(1);
return s ?? null;
})
);
sitesToProcess = foundSites.filter((s): s is Site => s !== null);
} }
if (sitesToProcess.length === 0) { if (!site) {
logger.error( logger.error(`handleOlmServerPeerAddMessage: Site not found`);
`handleOlmServerInitAddPeerHandshake: No sites to process`
);
await sendCancel();
return; return;
} }
let handshakeInitiated = false; // check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
for (const site of sitesToProcess) { if (currentSiteAssociationCaches.length === 0) {
// Check if the client can access this site using the cache logger.error(
const currentSiteAssociationCaches = await db `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
.select() );
.from(clientSitesAssociationsCache) // cancel the request from the olm side to not keep doing this
.where( await sendToClient(
and( olm.olmId,
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
if (currentSiteAssociationCaches.length === 0) {
logger.warn(
`handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to site ${site.siteId}, skipping`
);
continue;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Site ${site.siteId} has no exit node, skipping`
);
continue;
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Exit node not found for site ${site.siteId}, skipping`
);
continue;
}
// Trigger the peer add handshake — if the peer was already added this will be a no-op
await initPeerAddHandshake(
client.clientId,
{ {
siteId: site.siteId, type: "olm/wg/peer/chain/cancel",
exitNode: { data: {
publicKey: exitNode.publicKey, chainId
endpoint: exitNode.endpoint
} }
}, },
olm.olmId, { incrementConfigVersion: false }
chainId ).catch((error) => {
); logger.warn(`Error sending message:`, error);
});
handshakeInitiated = true; return;
} }
if (!handshakeInitiated) { if (!site.exitNodeId) {
logger.error( logger.error(
`handleOlmServerInitAddPeerHandshake: No accessible sites with valid exit nodes found, cancelling chain` `handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
); );
await sendCancel(); // cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
} }
// get the exit node from the side
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
return;
}
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId,
chainId
);
return; return;
}; };

View File

@@ -1,25 +1,43 @@
import { import {
Client,
clientSiteResourcesAssociationsCache, clientSiteResourcesAssociationsCache,
db, db,
networks, ExitNode,
siteNetworks, Org,
orgs,
roleClients,
roles,
siteResources, siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db"; } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { import {
clients, clients,
clientSitesAssociationsCache, clientSitesAssociationsCache,
exitNodes,
Olm, Olm,
olms,
sites sites
} from "@server/db"; } from "@server/db";
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger"; import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { import {
generateAliasConfig, generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip"; } from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import { import {
addPeer as newtAddPeer, addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers"; } from "@server/routers/newt/peers";
export const handleOlmServerPeerAddMessage: MessageHandler = async ( export const handleOlmServerPeerAddMessage: MessageHandler = async (
@@ -135,21 +153,13 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
clientSiteResourcesAssociationsCache.siteResourceId clientSiteResourcesAssociationsCache.siteResourceId
) )
) )
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
and(
eq(networks.networkId, siteNetworks.networkId),
eq(siteNetworks.siteId, site.siteId)
)
)
.where( .where(
eq( and(
clientSiteResourcesAssociationsCache.clientId, eq(siteResources.siteId, site.siteId),
client.clientId eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
) )
); );

View File

@@ -1,7 +1,7 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, sites } from "@server/db"; import { db, sites, siteBandwidth } from "@server/db";
import { eq } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -60,12 +60,17 @@ export async function resetOrgBandwidth(
} }
await db await db
.update(sites) .update(siteBandwidth)
.set({ .set({
megabytesIn: 0, megabytesIn: 0,
megabytesOut: 0 megabytesOut: 0
}) })
.where(eq(sites.orgId, orgId)); .where(
inArray(
siteBandwidth.siteId,
db.select({ siteId: sites.siteId }).from(sites).where(eq(sites.orgId, orgId))
)
);
return response(res, { return response(res, {
data: {}, data: {},

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, domainNamespaces, loginPage } from "@server/db"; import { db, loginPage } from "@server/db";
import { import {
domains, domains,
orgDomains, orgDomains,
@@ -24,8 +24,6 @@ import { build } from "@server/build";
import { createCertificate } from "#dynamic/routers/certificates/createCertificate"; import { createCertificate } from "#dynamic/routers/certificates/createCertificate";
import { getUniqueResourceName } from "@server/db/names"; import { getUniqueResourceName } from "@server/db/names";
import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
const createResourceParamsSchema = z.strictObject({ const createResourceParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -114,10 +112,7 @@ export async function createResource(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if ( if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
req.user &&
(!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)
) {
return next( return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role") createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
); );
@@ -198,29 +193,6 @@ async function createHttpResource(
const subdomain = parsedBody.data.subdomain; const subdomain = parsedBody.data.subdomain;
const stickySession = parsedBody.data.stickySession; const stickySession = parsedBody.data.stickySession;
if (build == "saas" && !isSubscribed(orgId!, tierMatrix.domainNamespaces)) {
// grandfather in existing users
const lastAllowedDate = new Date("2026-04-12");
const userCreatedDate = new Date(req.user?.dateCreated || new Date());
if (userCreatedDate > lastAllowedDate) {
// check if this domain id is a namespace domain and if so, reject
const domain = await db
.select()
.from(domainNamespaces)
.where(eq(domainNamespaces.domainId, domainId))
.limit(1);
if (domain.length > 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Your current subscription does not support custom domain namespaces. Please upgrade to access this feature."
)
);
}
}
}
// Validate domain and construct full domain // Validate domain and construct full domain
const domainResult = await validateAndConstructDomain( const domainResult = await validateAndConstructDomain(
domainId, domainId,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, domainNamespaces, loginPage } from "@server/db"; import { db, loginPage } from "@server/db";
import { import {
domains, domains,
Org, Org,
@@ -25,7 +25,6 @@ import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { build } from "@server/build"; import { build } from "@server/build";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
const updateResourceParamsSchema = z.strictObject({ const updateResourceParamsSchema = z.strictObject({
resourceId: z.string().transform(Number).pipe(z.int().positive()) resourceId: z.string().transform(Number).pipe(z.int().positive())
@@ -121,9 +120,7 @@ const updateHttpResourceBodySchema = z
if (data.headers) { if (data.headers) {
// HTTP header values must be visible ASCII or horizontal whitespace, no control chars (RFC 7230) // HTTP header values must be visible ASCII or horizontal whitespace, no control chars (RFC 7230)
const validHeaderValue = /^[\t\x20-\x7E]*$/; const validHeaderValue = /^[\t\x20-\x7E]*$/;
return data.headers.every((h) => return data.headers.every((h) => validHeaderValue.test(h.value));
validHeaderValue.test(h.value)
);
} }
return true; return true;
}, },
@@ -321,34 +318,6 @@ async function updateHttpResource(
if (updateData.domainId) { if (updateData.domainId) {
const domainId = updateData.domainId; const domainId = updateData.domainId;
if (
build == "saas" &&
!isSubscribed(resource.orgId, tierMatrix.domainNamespaces)
) {
// grandfather in existing users
const lastAllowedDate = new Date("2026-04-12");
const userCreatedDate = new Date(
req.user?.dateCreated || new Date()
);
if (userCreatedDate > lastAllowedDate) {
// check if this domain id is a namespace domain and if so, reject
const domain = await db
.select()
.from(domainNamespaces)
.where(eq(domainNamespaces.domainId, domainId))
.limit(1);
if (domain.length > 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Your current subscription does not support custom domain namespaces. Please upgrade to access this feature."
)
);
}
}
}
// Validate domain and construct full domain // Validate domain and construct full domain
const domainResult = await validateAndConstructDomain( const domainResult = await validateAndConstructDomain(
domainId, domainId,
@@ -397,7 +366,7 @@ async function updateHttpResource(
); );
} }
} }
if (build != "oss") { if (build != "oss") {
const existingLoginPages = await db const existingLoginPages = await db
.select() .select()

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, Site, siteNetworks, siteResources } from "@server/db"; import { db, Site, siteResources } from "@server/db";
import { newts, newtSessions, sites } from "@server/db"; import { newts, newtSessions, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -71,23 +71,18 @@ export async function deleteSite(
await deletePeer(site.exitNodeId!, site.pubKey); await deletePeer(site.exitNodeId!, site.pubKey);
} }
} else if (site.type == "newt") { } else if (site.type == "newt") {
const networks = await trx // delete all of the site resources on this site
.select({ networkId: siteNetworks.networkId }) const siteResourcesOnSite = trx
.from(siteNetworks) .delete(siteResources)
.where(eq(siteNetworks.siteId, siteId)); .where(eq(siteResources.siteId, siteId))
.returning();
// loop through them // loop through them
for (const network of await networks) { for (const removedSiteResource of await siteResourcesOnSite) {
const [siteResource] = await trx await rebuildClientAssociationsFromSiteResource(
.select() removedSiteResource,
.from(siteResources) trx
.where(eq(siteResources.networkId, network.networkId)); );
if (siteResource) {
await rebuildClientAssociationsFromSiteResource(
siteResource,
trx
);
}
} }
// get the newt on the site by querying the newt table for siteId // get the newt on the site by querying the newt table for siteId

View File

@@ -6,6 +6,7 @@ import {
remoteExitNodes, remoteExitNodes,
roleSites, roleSites,
sites, sites,
siteBandwidth,
userSites userSites
} from "@server/db"; } from "@server/db";
import cache from "#dynamic/lib/cache"; import cache from "#dynamic/lib/cache";
@@ -155,8 +156,8 @@ function querySitesBase() {
name: sites.name, name: sites.name,
pubKey: sites.pubKey, pubKey: sites.pubKey,
subnet: sites.subnet, subnet: sites.subnet,
megabytesIn: sites.megabytesIn, megabytesIn: siteBandwidth.megabytesIn,
megabytesOut: sites.megabytesOut, megabytesOut: siteBandwidth.megabytesOut,
orgName: orgs.name, orgName: orgs.name,
type: sites.type, type: sites.type,
online: sites.online, online: sites.online,
@@ -175,7 +176,8 @@ function querySitesBase() {
.leftJoin( .leftJoin(
remoteExitNodes, remoteExitNodes,
eq(remoteExitNodes.exitNodeId, sites.exitNodeId) eq(remoteExitNodes.exitNodeId, sites.exitNodeId)
); )
.leftJoin(siteBandwidth, eq(siteBandwidth.siteId, sites.siteId));
} }
type SiteWithUpdateAvailable = Awaited<ReturnType<typeof querySitesBase>>[0] & { type SiteWithUpdateAvailable = Awaited<ReturnType<typeof querySitesBase>>[0] & {
@@ -299,9 +301,15 @@ export async function listSites(
.offset(pageSize * (page - 1)) .offset(pageSize * (page - 1))
.orderBy( .orderBy(
sort_by sort_by
? order === "asc" ? (() => {
? asc(sites[sort_by]) const field =
: desc(sites[sort_by]) sort_by === "megabytesIn"
? siteBandwidth.megabytesIn
: sort_by === "megabytesOut"
? siteBandwidth.megabytesOut
: sites.name;
return order === "asc" ? asc(field) : desc(field);
})()
: asc(sites.name) : asc(sites.name)
); );

View File

@@ -5,8 +5,6 @@ import {
orgs, orgs,
roles, roles,
roleSiteResources, roleSiteResources,
siteNetworks,
networks,
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
@@ -25,7 +23,7 @@ import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -39,7 +37,7 @@ const createSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]), mode: z.enum(["host", "cidr", "port"]),
siteIds: z.array(z.int()), siteId: z.int(),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),
@@ -161,7 +159,7 @@ export async function createSiteResource(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { const {
name, name,
siteIds, siteId,
mode, mode,
// protocol, // protocol,
// proxyPort, // proxyPort,
@@ -180,16 +178,14 @@ export async function createSiteResource(
} = parsedBody.data; } = parsedBody.data;
// Verify the site exists and belongs to the org // Verify the site exists and belongs to the org
const sitesToAssign = await db const [site] = await db
.select() .select()
.from(sites) .from(sites)
.where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId))) .where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1); .limit(1);
if (sitesToAssign.length !== siteIds.length) { if (!site) {
return next( return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
} }
const [org] = await db const [org] = await db
@@ -291,29 +287,12 @@ export async function createSiteResource(
let newSiteResource: SiteResource | undefined; let newSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
if (!network) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to create network`
)
);
}
// Create the site resource // Create the site resource
const insertValues: typeof siteResources.$inferInsert = { const insertValues: typeof siteResources.$inferInsert = {
siteId,
niceId, niceId,
orgId, orgId,
name, name,
networkId: network.networkId,
mode: mode as "host" | "cidr", mode: mode as "host" | "cidr",
destination, destination,
enabled, enabled,
@@ -338,13 +317,6 @@ export async function createSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -387,21 +359,16 @@ export async function createSiteResource(
); );
} }
for (const siteToAssign of sitesToAssign) { const [newt] = await trx
const [newt] = await trx .select()
.select() .from(newts)
.from(newts) .where(eq(newts.siteId, site.siteId))
.where(eq(newts.siteId, siteToAssign.siteId)) .limit(1);
.limit(1);
if (!newt) { if (!newt) {
return next( return next(
createHttpError( createHttpError(HttpCode.NOT_FOUND, "Newt not found")
HttpCode.NOT_FOUND, );
`Newt not found for site ${siteToAssign.siteId}`
)
);
}
} }
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
@@ -420,7 +387,7 @@ export async function createSiteResource(
} }
logger.info( logger.info(
`Created site resource ${newSiteResource.siteResourceId} for org ${orgId}` `Created site resource ${newSiteResource.siteResourceId} for site ${siteId}`
); );
return response(res, { return response(res, {

View File

@@ -70,18 +70,17 @@ export async function deleteSiteResource(
.where(and(eq(siteResources.siteResourceId, siteResourceId))) .where(and(eq(siteResources.siteResourceId, siteResourceId)))
.returning(); .returning();
// not sure why this is here... const [newt] = await trx
// const [newt] = await trx .select()
// .select() .from(newts)
// .from(newts) .where(eq(newts.siteId, removedSiteResource.siteId))
// .where(eq(newts.siteId, removedSiteResource.siteId)) .limit(1);
// .limit(1);
// if (!newt) { if (!newt) {
// return next( return next(
// createHttpError(HttpCode.NOT_FOUND, "Newt not found") createHttpError(HttpCode.NOT_FOUND, "Newt not found")
// ); );
// } }
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
removedSiteResource, removedSiteResource,

View File

@@ -17,34 +17,38 @@ const getSiteResourceParamsSchema = z.strictObject({
.transform((val) => (val ? Number(val) : undefined)) .transform((val) => (val ? Number(val) : undefined))
.pipe(z.int().positive().optional()) .pipe(z.int().positive().optional())
.optional(), .optional(),
siteId: z.string().transform(Number).pipe(z.int().positive()),
niceId: z.string().optional(), niceId: z.string().optional(),
orgId: z.string() orgId: z.string()
}); });
async function query( async function query(
siteResourceId?: number, siteResourceId?: number,
siteId?: number,
niceId?: string, niceId?: string,
orgId?: string orgId?: string
) { ) {
if (siteResourceId && orgId) { if (siteResourceId && siteId && orgId) {
const [siteResource] = await db const [siteResource] = await db
.select() .select()
.from(siteResources) .from(siteResources)
.where( .where(
and( and(
eq(siteResources.siteResourceId, siteResourceId), eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
) )
.limit(1); .limit(1);
return siteResource; return siteResource;
} else if (niceId && orgId) { } else if (niceId && siteId && orgId) {
const [siteResource] = await db const [siteResource] = await db
.select() .select()
.from(siteResources) .from(siteResources)
.where( .where(
and( and(
eq(siteResources.niceId, niceId), eq(siteResources.niceId, niceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
) )
@@ -80,6 +84,7 @@ registry.registerPath({
request: { request: {
params: z.object({ params: z.object({
niceId: z.string(), niceId: z.string(),
siteId: z.number(),
orgId: z.string() orgId: z.string()
}) })
}, },
@@ -102,10 +107,10 @@ export async function getSiteResource(
); );
} }
const { siteResourceId, niceId, orgId } = parsedParams.data; const { siteResourceId, siteId, niceId, orgId } = parsedParams.data;
// Get the site resource // Get the site resource
const siteResource = await query(siteResourceId, niceId, orgId); const siteResource = await query(siteResourceId, siteId, niceId, orgId);
if (!siteResource) { if (!siteResource) {
return next( return next(

View File

@@ -1,4 +1,4 @@
import { db, SiteResource, siteNetworks, siteResources, sites } from "@server/db"; import { db, SiteResource, siteResources, sites } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -73,10 +73,9 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{ export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
siteResources: (SiteResource & { siteResources: (SiteResource & {
siteIds: number[]; siteName: string;
siteNames: string[]; siteNiceId: string;
siteNiceIds: string[]; siteAddress: string | null;
siteAddresses: (string | null)[];
})[]; })[];
}>; }>;
@@ -84,6 +83,7 @@ function querySiteResourcesBase() {
return db return db
.select({ .select({
siteResourceId: siteResources.siteResourceId, siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId, orgId: siteResources.orgId,
niceId: siteResources.niceId, niceId: siteResources.niceId,
name: siteResources.name, name: siteResources.name,
@@ -100,20 +100,14 @@ function querySiteResourcesBase() {
disableIcmp: siteResources.disableIcmp, disableIcmp: siteResources.disableIcmp,
authDaemonMode: siteResources.authDaemonMode, authDaemonMode: siteResources.authDaemonMode,
authDaemonPort: siteResources.authDaemonPort, authDaemonPort: siteResources.authDaemonPort,
networkId: siteResources.networkId, siteName: sites.name,
defaultNetworkId: siteResources.defaultNetworkId, siteNiceId: sites.niceId,
siteNames: sql<string[]>`array_agg(${sites.name})`, siteAddress: sites.address
siteNiceIds: sql<string[]>`array_agg(${sites.niceId})`,
siteIds: sql<number[]>`array_agg(${sites.siteId})`,
siteAddresses: sql<(string | null)[]>`array_agg(${sites.address})`
}) })
.from(siteResources) .from(siteResources)
.innerJoin(siteNetworks, eq(siteResources.networkId, siteNetworks.networkId)) .innerJoin(sites, eq(siteResources.siteId, sites.siteId));
.innerJoin(sites, eq(siteNetworks.siteId, sites.siteId))
.groupBy(siteResources.siteResourceId);
} }
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
path: "/org/{orgId}/site-resources", path: "/org/{orgId}/site-resources",

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, networks, siteNetworks } from "@server/db"; import { db } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db"; import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -108,21 +108,13 @@ export async function listSiteResources(
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
} }
// Get site resources by joining networks to siteResources via siteNetworks // Get site resources
const siteResourcesList = await db const siteResourcesList = await db
.select() .select()
.from(siteNetworks) .from(siteResources)
.innerJoin(
networks,
eq(siteNetworks.networkId, networks.networkId)
)
.innerJoin(
siteResources,
eq(siteResources.networkId, networks.networkId)
)
.where( .where(
and( and(
eq(siteNetworks.siteId, siteId), eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
) )
@@ -136,7 +128,6 @@ export async function listSiteResources(
.limit(limit) .limit(limit)
.offset(offset); .offset(offset);
return response(res, { return response(res, {
data: { siteResources: siteResourcesList }, data: { siteResources: siteResourcesList },
success: true, success: true,

View File

@@ -7,18 +7,12 @@ import {
orgs, orgs,
roles, roles,
roleSiteResources, roleSiteResources,
siteNetworks,
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
networks,
Transaction, Transaction,
userSiteResources userSiteResources
} from "@server/db"; } from "@server/db";
import response from "@server/lib/response";
import { eq, and, ne, inArray } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { import {
generateAliasConfig, generateAliasConfig,
@@ -28,8 +22,12 @@ import {
portRangeStringSchema portRangeStringSchema
} from "@server/lib/ip"; } from "@server/lib/ip";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq, ne } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -42,8 +40,7 @@ const updateSiteResourceParamsSchema = z.strictObject({
const updateSiteResourceSchema = z const updateSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
siteIds: z.array(z.int()), siteId: z.int(),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
niceId: z niceId: z
.string() .string()
.min(1) .min(1)
@@ -175,7 +172,7 @@ export async function updateSiteResource(
const { siteResourceId } = parsedParams.data; const { siteResourceId } = parsedParams.data;
const { const {
name, name,
siteIds, // because it can change siteId, // because it can change
niceId, niceId,
mode, mode,
destination, destination,
@@ -191,6 +188,16 @@ export async function updateSiteResource(
authDaemonMode authDaemonMode
} = parsedBody.data; } = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists // Check if site resource exists
const [existingSiteResource] = await db const [existingSiteResource] = await db
.select() .select()
@@ -230,24 +237,6 @@ export async function updateSiteResource(
); );
} }
// Verify the site exists and belongs to the org
const sitesToAssign = await db
.select()
.from(sites)
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.orgId, existingSiteResource.orgId)
)
)
.limit(1);
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
// Only check if destination is an IP address // Only check if destination is an IP address
const isIp = z const isIp = z
.union([z.ipv4(), z.ipv6()]) .union([z.ipv4(), z.ipv6()])
@@ -265,24 +254,25 @@ export async function updateSiteResource(
); );
} }
let sitesChanged = false; let existingSite = site;
const existingSiteIds = existingSiteResource.networkId let siteChanged = false;
? await db if (existingSiteResource.siteId !== siteId) {
.select() siteChanged = true;
.from(siteNetworks) // get the existing site
.where( [existingSite] = await db
eq(siteNetworks.networkId, existingSiteResource.networkId) .select()
) .from(sites)
: []; .where(eq(sites.siteId, existingSiteResource.siteId))
.limit(1);
const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId)); if (!existingSite) {
const newSiteIdSet = new Set(siteIds); return next(
createHttpError(
if ( HttpCode.NOT_FOUND,
existingSiteIdSet.size !== newSiteIdSet.size || "Existing site not found"
![...existingSiteIdSet].every((id) => newSiteIdSet.has(id)) )
) { );
sitesChanged = true; }
} }
// make sure the alias is unique within the org if provided // make sure the alias is unique within the org if provided
@@ -312,7 +302,7 @@ export async function updateSiteResource(
let updatedSiteResource: SiteResource | undefined; let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place // if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place
if (sitesChanged) { if (siteChanged) {
// delete the existing site resource // delete the existing site resource
await trx await trx
.delete(siteResources) .delete(siteResources)
@@ -353,6 +343,7 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name, name,
siteId,
niceId, niceId,
mode, mode,
destination, destination,
@@ -456,6 +447,7 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -472,23 +464,6 @@ export async function updateSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(
siteNetworks.networkId,
updatedSiteResource.networkId!
)
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
});
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where( .where(
@@ -558,15 +533,14 @@ export async function updateSiteResource(
); );
} }
logger.info(`Updated site resource ${siteResourceId}`); logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}`
);
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
existingSiteResource, existingSiteResource,
updatedSiteResource, updatedSiteResource,
siteIds.map((siteId) => ({ { siteId: site.siteId, orgId: site.orgId },
siteId,
orgId: existingSiteResource.orgId
})),
trx trx
); );
} }
@@ -593,7 +567,7 @@ export async function updateSiteResource(
export async function handleMessagingForUpdatedSiteResource( export async function handleMessagingForUpdatedSiteResource(
existingSiteResource: SiteResource | undefined, existingSiteResource: SiteResource | undefined,
updatedSiteResource: SiteResource, updatedSiteResource: SiteResource,
sites: { siteId: number; orgId: string }[], site: { siteId: number; orgId: string },
trx: Transaction trx: Transaction
) { ) {
logger.debug( logger.debug(
@@ -630,112 +604,105 @@ export async function handleMessagingForUpdatedSiteResource(
// if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all // if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all
if (destinationChanged || aliasChanged || portRangesChanged) { if (destinationChanged || aliasChanged || portRangesChanged) {
for (const site of sites) { const [newt] = await trx
const [newt] = await trx .select()
.select() .from(newts)
.from(newts) .where(eq(newts.siteId, site.siteId))
.where(eq(newts.siteId, site.siteId)) .limit(1);
.limit(1);
if (!newt) { if (!newt) {
throw new Error( throw new Error(
"Newt not found for site during site resource update" "Newt not found for site during site resource update"
); );
}
// Only update targets on newt if destination changed
if (destinationChanged || portRangesChanged) {
const oldTarget = generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTarget = generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
await updateTargets(
newt.newtId,
{
oldTargets: oldTarget ? [oldTarget] : [],
newTargets: newTarget ? [newTarget] : []
},
newt.version
);
}
const olmJobs: Promise<void>[] = [];
for (const client of mergedAllClients) {
// does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet
// todo: optimize this query if needed
const oldDestinationStillInUseSites = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
)
)
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteNetworks.siteId, site.siteId),
eq(
siteResources.destination,
existingSiteResource.destination
),
ne(
siteResources.siteResourceId,
existingSiteResource.siteResourceId
)
)
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;
// we also need to update the remote subnets on the olms for each client that has access to this site
olmJobs.push(
updatePeerData(
client.clientId,
site.siteId,
destinationChanged
? {
oldRemoteSubnets:
!oldDestinationStillInUseByASite
? generateRemoteSubnets([
existingSiteResource
])
: [],
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
}
: undefined,
aliasChanged
? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
: undefined
)
);
}
await Promise.all(olmJobs);
} }
// Only update targets on newt if destination changed
if (destinationChanged || portRangesChanged) {
const oldTarget = generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTarget = generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
await updateTargets(
newt.newtId,
{
oldTargets: oldTarget ? [oldTarget] : [],
newTargets: newTarget ? [newTarget] : []
},
newt.version
);
}
const olmJobs: Promise<void>[] = [];
for (const client of mergedAllClients) {
// does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet
// todo: optimize this query if needed
const oldDestinationStillInUseSites = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteResources.siteId, site.siteId),
eq(
siteResources.destination,
existingSiteResource.destination
),
ne(
siteResources.siteResourceId,
existingSiteResource.siteResourceId
)
)
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;
// we also need to update the remote subnets on the olms for each client that has access to this site
olmJobs.push(
updatePeerData(
client.clientId,
updatedSiteResource.siteId,
destinationChanged
? {
oldRemoteSubnets: !oldDestinationStillInUseByASite
? generateRemoteSubnets([
existingSiteResource
])
: [],
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
}
: undefined,
aliasChanged
? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
: undefined
)
);
}
await Promise.all(olmJobs);
} }
} }

View File

@@ -1,14 +1,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db } from "@server/db";
import { import { orgs, roles, userInviteRoles, userInvites, userOrgs, users } from "@server/db";
orgs,
roles,
userInviteRoles,
userInvites,
userOrgs,
users
} from "@server/db";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -44,7 +37,8 @@ const inviteUserBodySchema = z
regenerate: z.boolean().optional() regenerate: z.boolean().optional()
}) })
.refine( .refine(
(d) => (d.roleIds != null && d.roleIds.length > 0) || d.roleId != null, (d) =>
(d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
{ message: "roleIds or roleId is required", path: ["roleIds"] } { message: "roleIds or roleId is required", path: ["roleIds"] }
) )
.transform((data) => ({ .transform((data) => ({
@@ -271,7 +265,7 @@ export async function inviteUser(
) )
); );
const inviteLink = `${config.getRawConfig().app.dashboard_url}/invite?token=${inviteId}-${token}&email=${email}`; const inviteLink = `${config.getRawConfig().app.dashboard_url}/invite?token=${inviteId}-${token}&email=${encodeURIComponent(email)}`;
if (doEmail) { if (doEmail) {
await sendEmail( await sendEmail(
@@ -320,12 +314,12 @@ export async function inviteUser(
expiresAt, expiresAt,
tokenHash tokenHash
}); });
await trx await trx.insert(userInviteRoles).values(
.insert(userInviteRoles) uniqueRoleIds.map((roleId) => ({ inviteId, roleId }))
.values(uniqueRoleIds.map((roleId) => ({ inviteId, roleId }))); );
}); });
const inviteLink = `${config.getRawConfig().app.dashboard_url}/invite?token=${inviteId}-${token}&email=${email}`; const inviteLink = `${config.getRawConfig().app.dashboard_url}/invite?token=${inviteId}-${token}&email=${encodeURIComponent(email)}`;
if (doEmail) { if (doEmail) {
await sendEmail( await sendEmail(

View File

@@ -22,6 +22,7 @@ import m13 from "./scriptsPg/1.15.3";
import m14 from "./scriptsPg/1.15.4"; import m14 from "./scriptsPg/1.15.4";
import m15 from "./scriptsPg/1.16.0"; import m15 from "./scriptsPg/1.16.0";
import m16 from "./scriptsPg/1.17.0"; import m16 from "./scriptsPg/1.17.0";
import m17 from "./scriptsPg/1.18.0";
// THIS CANNOT IMPORT ANYTHING FROM THE SERVER // THIS CANNOT IMPORT ANYTHING FROM THE SERVER
// EXCEPT FOR THE DATABASE AND THE SCHEMA // EXCEPT FOR THE DATABASE AND THE SCHEMA
@@ -43,7 +44,8 @@ const migrations = [
{ version: "1.15.3", run: m13 }, { version: "1.15.3", run: m13 },
{ version: "1.15.4", run: m14 }, { version: "1.15.4", run: m14 },
{ version: "1.16.0", run: m15 }, { version: "1.16.0", run: m15 },
{ version: "1.17.0", run: m16 } { version: "1.17.0", run: m16 },
{ version: "1.18.0", run: m17 }
// Add new migrations here as they are created // Add new migrations here as they are created
] as { ] as {
version: string; version: string;

View File

@@ -40,6 +40,7 @@ import m34 from "./scriptsSqlite/1.15.3";
import m35 from "./scriptsSqlite/1.15.4"; import m35 from "./scriptsSqlite/1.15.4";
import m36 from "./scriptsSqlite/1.16.0"; import m36 from "./scriptsSqlite/1.16.0";
import m37 from "./scriptsSqlite/1.17.0"; import m37 from "./scriptsSqlite/1.17.0";
import m38 from "./scriptsSqlite/1.18.0";
// THIS CANNOT IMPORT ANYTHING FROM THE SERVER // THIS CANNOT IMPORT ANYTHING FROM THE SERVER
// EXCEPT FOR THE DATABASE AND THE SCHEMA // EXCEPT FOR THE DATABASE AND THE SCHEMA
@@ -77,7 +78,8 @@ const migrations = [
{ version: "1.15.3", run: m34 }, { version: "1.15.3", run: m34 },
{ version: "1.15.4", run: m35 }, { version: "1.15.4", run: m35 },
{ version: "1.16.0", run: m36 }, { version: "1.16.0", run: m36 },
{ version: "1.17.0", run: m37 } { version: "1.17.0", run: m37 },
{ version: "1.18.0", run: m38 }
// Add new migrations here as they are created // Add new migrations here as they are created
] as const; ] as const;

View File

@@ -235,9 +235,7 @@ export default async function migration() {
for (const row of existingUserInviteRoles) { for (const row of existingUserInviteRoles) {
await db.execute(sql` await db.execute(sql`
INSERT INTO "userInviteRoles" ("inviteId", "roleId") INSERT INTO "userInviteRoles" ("inviteId", "roleId")
SELECT ${row.inviteId}, ${row.roleId} VALUES (${row.inviteId}, ${row.roleId})
WHERE EXISTS (SELECT 1 FROM "userInvites" WHERE "inviteId" = ${row.inviteId})
AND EXISTS (SELECT 1 FROM "roles" WHERE "roleId" = ${row.roleId})
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
`); `);
} }
@@ -260,10 +258,7 @@ export default async function migration() {
for (const row of existingUserOrgRoles) { for (const row of existingUserOrgRoles) {
await db.execute(sql` await db.execute(sql`
INSERT INTO "userOrgRoles" ("userId", "orgId", "roleId") INSERT INTO "userOrgRoles" ("userId", "orgId", "roleId")
SELECT ${row.userId}, ${row.orgId}, ${row.roleId} VALUES (${row.userId}, ${row.orgId}, ${row.roleId})
WHERE EXISTS (SELECT 1 FROM "user" WHERE "id" = ${row.userId})
AND EXISTS (SELECT 1 FROM "orgs" WHERE "orgId" = ${row.orgId})
AND EXISTS (SELECT 1 FROM "roles" WHERE "roleId" = ${row.roleId})
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
`); `);
} }

View File

@@ -145,7 +145,7 @@ export default async function migration() {
).run(); ).run();
db.prepare( db.prepare(
`INSERT INTO '__new_userOrgs'("userId", "orgId", "isOwner", "autoProvisioned", "pamUsername") SELECT "userId", "orgId", "isOwner", "autoProvisioned", "pamUsername" FROM 'userOrgs' WHERE EXISTS (SELECT 1 FROM 'user' WHERE id = userOrgs.userId) AND EXISTS (SELECT 1 FROM 'orgs' WHERE orgId = userOrgs.orgId);` `INSERT INTO '__new_userOrgs'("userId", "orgId", "isOwner", "autoProvisioned", "pamUsername") SELECT "userId", "orgId", "isOwner", "autoProvisioned", "pamUsername" FROM 'userOrgs';`
).run(); ).run();
db.prepare(`DROP TABLE 'userOrgs';`).run(); db.prepare(`DROP TABLE 'userOrgs';`).run();
db.prepare( db.prepare(
@@ -246,15 +246,12 @@ export default async function migration() {
// Re-insert the preserved invite role assignments into the new userInviteRoles table // Re-insert the preserved invite role assignments into the new userInviteRoles table
if (existingUserInviteRoles.length > 0) { if (existingUserInviteRoles.length > 0) {
const insertUserInviteRole = db.prepare( const insertUserInviteRole = db.prepare(
`INSERT OR IGNORE INTO 'userInviteRoles' ("inviteId", "roleId") `INSERT OR IGNORE INTO 'userInviteRoles' ("inviteId", "roleId") VALUES (?, ?)`
SELECT ?, ?
WHERE EXISTS (SELECT 1 FROM 'userInvites' WHERE inviteId = ?)
AND EXISTS (SELECT 1 FROM 'roles' WHERE roleId = ?)`
); );
const insertAll = db.transaction(() => { const insertAll = db.transaction(() => {
for (const row of existingUserInviteRoles) { for (const row of existingUserInviteRoles) {
insertUserInviteRole.run(row.inviteId, row.roleId, row.inviteId, row.roleId); insertUserInviteRole.run(row.inviteId, row.roleId);
} }
}); });
@@ -268,16 +265,12 @@ export default async function migration() {
// Re-insert the preserved role assignments into the new userOrgRoles table // Re-insert the preserved role assignments into the new userOrgRoles table
if (existingUserOrgRoles.length > 0) { if (existingUserOrgRoles.length > 0) {
const insertUserOrgRole = db.prepare( const insertUserOrgRole = db.prepare(
`INSERT OR IGNORE INTO 'userOrgRoles' ("userId", "orgId", "roleId") `INSERT OR IGNORE INTO 'userOrgRoles' ("userId", "orgId", "roleId") VALUES (?, ?, ?)`
SELECT ?, ?, ?
WHERE EXISTS (SELECT 1 FROM 'user' WHERE id = ?)
AND EXISTS (SELECT 1 FROM 'orgs' WHERE orgId = ?)
AND EXISTS (SELECT 1 FROM 'roles' WHERE roleId = ?)`
); );
const insertAll = db.transaction(() => { const insertAll = db.transaction(() => {
for (const row of existingUserOrgRoles) { for (const row of existingUserOrgRoles) {
insertUserOrgRole.run(row.userId, row.orgId, row.roleId, row.userId, row.orgId, row.roleId); insertUserOrgRole.run(row.userId, row.orgId, row.roleId);
} }
}); });

View File

@@ -10,7 +10,6 @@ import { authCookieHeader } from "@app/lib/api/cookies";
import { GetDNSRecordsResponse } from "@server/routers/domain"; import { GetDNSRecordsResponse } from "@server/routers/domain";
import DNSRecordsTable from "@app/components/DNSRecordTable"; import DNSRecordsTable from "@app/components/DNSRecordTable";
import DomainCertForm from "@app/components/DomainCertForm"; import DomainCertForm from "@app/components/DomainCertForm";
import { build } from "@server/build";
interface DomainSettingsPageProps { interface DomainSettingsPageProps {
params: Promise<{ domainId: string; orgId: string }>; params: Promise<{ domainId: string; orgId: string }>;
@@ -66,14 +65,12 @@ export default async function DomainSettingsPage({
)} )}
</div> </div>
<div className="space-y-6"> <div className="space-y-6">
{build != "oss" && env.flags.usePangolinDns ? ( <DomainInfoCard
<DomainInfoCard failed={domain.failed}
failed={domain.failed} verified={domain.verified}
verified={domain.verified} type={domain.type}
type={domain.type} errorMessage={domain.errorMessage}
errorMessage={domain.errorMessage} />
/>
) : null}
<DNSRecordsTable records={dnsRecords} type={domain.type} /> <DNSRecordsTable records={dnsRecords} type={domain.type} />

View File

@@ -60,17 +60,17 @@ export default async function ClientResourcesPage(
id: siteResource.siteResourceId, id: siteResource.siteResourceId,
name: siteResource.name, name: siteResource.name,
orgId: params.orgId, orgId: params.orgId,
siteNames: siteResource.siteNames, siteName: siteResource.siteName,
siteAddresses: siteResource.siteAddresses || null, siteAddress: siteResource.siteAddress || null,
mode: siteResource.mode || ("port" as any), mode: siteResource.mode || ("port" as any),
// protocol: siteResource.protocol, // protocol: siteResource.protocol,
// proxyPort: siteResource.proxyPort, // proxyPort: siteResource.proxyPort,
siteIds: siteResource.siteIds, siteId: siteResource.siteId,
destination: siteResource.destination, destination: siteResource.destination,
// destinationPort: siteResource.destinationPort, // destinationPort: siteResource.destinationPort,
alias: siteResource.alias || null, alias: siteResource.alias || null,
aliasAddress: siteResource.aliasAddress || null, aliasAddress: siteResource.aliasAddress || null,
siteNiceIds: siteResource.siteNiceIds, siteNiceId: siteResource.siteNiceId,
niceId: siteResource.niceId, niceId: siteResource.niceId,
tcpPortRangeString: siteResource.tcpPortRangeString || null, tcpPortRangeString: siteResource.tcpPortRangeString || null,
udpPortRangeString: siteResource.udpPortRangeString || null, udpPortRangeString: siteResource.udpPortRangeString || null,

View File

@@ -678,7 +678,6 @@ function ProxyResourceTargetsForm({
getPaginationRowModel: getPaginationRowModel(), getPaginationRowModel: getPaginationRowModel(),
getSortedRowModel: getSortedRowModel(), getSortedRowModel: getSortedRowModel(),
getFilteredRowModel: getFilteredRowModel(), getFilteredRowModel: getFilteredRowModel(),
getRowId: (row) => String(row.targetId),
state: { state: {
pagination: { pagination: {
pageIndex: 0, pageIndex: 0,

View File

@@ -999,7 +999,6 @@ export default function Page() {
getPaginationRowModel: getPaginationRowModel(), getPaginationRowModel: getPaginationRowModel(),
getSortedRowModel: getSortedRowModel(), getSortedRowModel: getSortedRowModel(),
getFilteredRowModel: getFilteredRowModel(), getFilteredRowModel: getFilteredRowModel(),
getRowId: (row) => String(row.targetId),
state: { state: {
pagination: { pagination: {
pageIndex: 0, pageIndex: 0,

View File

@@ -21,7 +21,6 @@ import {
ArrowUp10Icon, ArrowUp10Icon,
ArrowUpDown, ArrowUpDown,
ArrowUpRight, ArrowUpRight,
ChevronDown,
ChevronsUpDownIcon, ChevronsUpDownIcon,
MoreHorizontal MoreHorizontal
} from "lucide-react"; } from "lucide-react";
@@ -44,14 +43,14 @@ export type InternalResourceRow = {
id: number; id: number;
name: string; name: string;
orgId: string; orgId: string;
siteNames: string[]; siteName: string;
siteAddresses: (string | null)[]; siteAddress: string | null;
siteIds: number[];
siteNiceIds: string[];
// mode: "host" | "cidr" | "port"; // mode: "host" | "cidr" | "port";
mode: "host" | "cidr"; mode: "host" | "cidr";
// protocol: string | null; // protocol: string | null;
// proxyPort: number | null; // proxyPort: number | null;
siteId: number;
siteNiceId: string;
destination: string; destination: string;
// destinationPort: number | null; // destinationPort: number | null;
alias: string | null; alias: string | null;
@@ -137,60 +136,6 @@ export default function ClientResourcesTable({
} }
}; };
function SiteCell({ resourceRow }: { resourceRow: InternalResourceRow }) {
const { siteNames, siteNiceIds, orgId } = resourceRow;
if (!siteNames || siteNames.length === 0) {
return <span>-</span>;
}
if (siteNames.length === 1) {
return (
<Link
href={`/${orgId}/settings/sites/${siteNiceIds[0]}`}
>
<Button variant="outline">
{siteNames[0]}
<ArrowUpRight className="ml-2 h-4 w-4" />
</Button>
</Link>
);
}
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
variant="outline"
size="sm"
className="flex items-center gap-2"
>
<span>
{siteNames.length} {t("sites")}
</span>
<ChevronDown className="h-3 w-3" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="start">
{siteNames.map((siteName, idx) => (
<DropdownMenuItem
key={siteNiceIds[idx]}
asChild
>
<Link
href={`/${orgId}/settings/sites/${siteNiceIds[idx]}`}
className="flex items-center gap-2 cursor-pointer"
>
{siteName}
<ArrowUpRight className="h-3 w-3" />
</Link>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
);
}
const internalColumns: ExtendedColumnDef<InternalResourceRow>[] = [ const internalColumns: ExtendedColumnDef<InternalResourceRow>[] = [
{ {
accessorKey: "name", accessorKey: "name",
@@ -240,11 +185,21 @@ export default function ClientResourcesTable({
} }
}, },
{ {
accessorKey: "siteNames", accessorKey: "siteName",
friendlyName: t("site"), friendlyName: t("site"),
header: () => <span className="p-3">{t("site")}</span>, header: () => <span className="p-3">{t("site")}</span>,
cell: ({ row }) => { cell: ({ row }) => {
return <SiteCell resourceRow={row.original} />; const resourceRow = row.original;
return (
<Link
href={`/${resourceRow.orgId}/settings/sites/${resourceRow.siteNiceId}`}
>
<Button variant="outline">
{resourceRow.siteName}
<ArrowUpRight className="ml-2 h-4 w-4" />
</Button>
</Link>
);
} }
}, },
{ {
@@ -444,7 +399,7 @@ export default function ClientResourcesTable({
onConfirm={async () => onConfirm={async () =>
deleteInternalResource( deleteInternalResource(
selectedInternalResource!.id, selectedInternalResource!.id,
selectedInternalResource!.siteIds[0] selectedInternalResource!.siteId
) )
} }
string={selectedInternalResource.name} string={selectedInternalResource.name}
@@ -478,11 +433,7 @@ export default function ClientResourcesTable({
<EditInternalResourceDialog <EditInternalResourceDialog
open={isEditDialogOpen} open={isEditDialogOpen}
setOpen={setIsEditDialogOpen} setOpen={setIsEditDialogOpen}
resource={{ resource={editingResource}
...editingResource,
siteName: editingResource.siteNames[0] ?? "",
siteId: editingResource.siteIds[0]
}}
orgId={orgId} orgId={orgId}
sites={sites} sites={sites}
onSuccess={() => { onSuccess={() => {

View File

@@ -154,7 +154,7 @@ export default function CreateDomainForm({
const punycodePreview = useMemo(() => { const punycodePreview = useMemo(() => {
if (!baseDomain) return ""; if (!baseDomain) return "";
const punycode = toPunycode(baseDomain.toLowerCase()); const punycode = toPunycode(baseDomain);
return punycode !== baseDomain.toLowerCase() ? punycode : ""; return punycode !== baseDomain.toLowerCase() ? punycode : "";
}, [baseDomain]); }, [baseDomain]);
@@ -239,24 +239,21 @@ export default function CreateDomainForm({
className="space-y-4" className="space-y-4"
id="create-domain-form" id="create-domain-form"
> >
{build != "oss" && env.flags.usePangolinDns ? ( <FormField
<FormField control={form.control}
control={form.control} name="type"
name="type" render={({ field }) => (
render={({ field }) => ( <FormItem>
<FormItem> <StrategySelect
<StrategySelect options={domainOptions}
options={domainOptions} defaultValue={field.value}
defaultValue={field.value} onChange={field.onChange}
onChange={field.onChange} cols={1}
cols={1} />
/> <FormMessage />
<FormMessage /> </FormItem>
</FormItem> )}
)} />
/>
) : null}
<FormField <FormField
control={form.control} control={form.control}
name="baseDomain" name="baseDomain"

View File

@@ -319,7 +319,6 @@ export default function DeviceLoginForm({
<div className="flex justify-center"> <div className="flex justify-center">
<InputOTP <InputOTP
maxLength={9} maxLength={9}
pattern={REGEXP_ONLY_DIGITS_AND_CHARS}
{...field} {...field}
value={field.value value={field.value
.replace(/-/g, "") .replace(/-/g, "")

View File

@@ -2,7 +2,6 @@
import { Alert, AlertDescription } from "@/components/ui/alert"; import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { import {
Command, Command,
CommandEmpty, CommandEmpty,
@@ -41,12 +40,9 @@ import {
Check, Check,
CheckCircle2, CheckCircle2,
ChevronsUpDown, ChevronsUpDown,
KeyRound,
Zap Zap
} from "lucide-react"; } from "lucide-react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { usePaidStatus } from "@/hooks/usePaidStatus";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { toUnicode } from "punycode"; import { toUnicode } from "punycode";
import { useCallback, useEffect, useMemo, useState } from "react"; import { useCallback, useEffect, useMemo, useState } from "react";
@@ -99,7 +95,6 @@ export default function DomainPicker({
const { env } = useEnvContext(); const { env } = useEnvContext();
const api = createApiClient({ env }); const api = createApiClient({ env });
const t = useTranslations(); const t = useTranslations();
const { hasSaasSubscription } = usePaidStatus();
const { data = [], isLoading: loadingDomains } = useQuery( const { data = [], isLoading: loadingDomains } = useQuery(
orgQueries.domains({ orgId }) orgQueries.domains({ orgId })
@@ -514,11 +509,9 @@ export default function DomainPicker({
<span className="truncate"> <span className="truncate">
{selectedBaseDomain.domain} {selectedBaseDomain.domain}
</span> </span>
{selectedBaseDomain.verified && {selectedBaseDomain.verified && (
selectedBaseDomain.domainType !== <CheckCircle2 className="h-3 w-3 text-green-500 shrink-0" />
"wildcard" && ( )}
<CheckCircle2 className="h-3 w-3 text-green-500 shrink-0" />
)}
</div> </div>
) : ( ) : (
t("domainPickerSelectBaseDomain") t("domainPickerSelectBaseDomain")
@@ -581,23 +574,14 @@ export default function DomainPicker({
} }
</span> </span>
<span className="text-xs text-muted-foreground"> <span className="text-xs text-muted-foreground">
{orgDomain.type === {orgDomain.type.toUpperCase()}{" "}
"wildcard" {" "}
{orgDomain.verified
? t( ? t(
"domainPickerManual" "domainPickerVerified"
) )
: ( : t(
<> "domainPickerUnverified"
{orgDomain.type.toUpperCase()}{" "}
{" "}
{orgDomain.verified
? t(
"domainPickerVerified"
)
: t(
"domainPickerUnverified"
)}
</>
)} )}
</span> </span>
</div> </div>
@@ -696,23 +680,6 @@ export default function DomainPicker({
</div> </div>
</div> </div>
{build === "saas" &&
!hasSaasSubscription(
tierMatrix[TierFeature.DomainNamespaces]
) &&
!hideFreeDomain && (
<Card className="mt-3 border-black-500/30 bg-linear-to-br from-black-500/10 via-background to-background overflow-hidden">
<CardContent className="py-3 px-4">
<div className="flex items-center gap-2.5 text-sm text-muted-foreground">
<KeyRound className="size-4 shrink-0 text-black-500" />
<span>
{t("domainPickerFreeDomainsPaidFeature")}
</span>
</div>
</CardContent>
</Card>
)}
{/*showProvidedDomainSearch && build === "saas" && ( {/*showProvidedDomainSearch && build === "saas" && (
<Alert> <Alert>
<AlertCircle className="h-4 w-4" /> <AlertCircle className="h-4 w-4" />

View File

@@ -39,11 +39,7 @@ export default function InviteStatusCard({
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [error, setError] = useState(""); const [error, setError] = useState("");
const [type, setType] = useState< const [type, setType] = useState<
| "rejected" "rejected" | "wrong_user" | "user_does_not_exist" | "not_logged_in" | "user_limit_exceeded"
| "wrong_user"
| "user_does_not_exist"
| "not_logged_in"
| "user_limit_exceeded"
>("rejected"); >("rejected");
useEffect(() => { useEffect(() => {
@@ -94,12 +90,12 @@ export default function InviteStatusCard({
if (!user && type === "user_does_not_exist") { if (!user && type === "user_does_not_exist") {
const redirectUrl = email const redirectUrl = email
? `/auth/signup?redirect=/invite?token=${tokenParam}&email=${email}` ? `/auth/signup?redirect=/invite?token=${tokenParam}&email=${encodeURIComponent(email)}`
: `/auth/signup?redirect=/invite?token=${tokenParam}`; : `/auth/signup?redirect=/invite?token=${tokenParam}`;
router.push(redirectUrl); router.push(redirectUrl);
} else if (!user && type === "not_logged_in") { } else if (!user && type === "not_logged_in") {
const redirectUrl = email const redirectUrl = email
? `/auth/login?redirect=/invite?token=${tokenParam}&email=${email}` ? `/auth/login?redirect=/invite?token=${tokenParam}&email=${encodeURIComponent(email)}`
: `/auth/login?redirect=/invite?token=${tokenParam}`; : `/auth/login?redirect=/invite?token=${tokenParam}`;
router.push(redirectUrl); router.push(redirectUrl);
} else { } else {
@@ -113,7 +109,7 @@ export default function InviteStatusCard({
async function goToLogin() { async function goToLogin() {
await api.post("/auth/logout", {}); await api.post("/auth/logout", {});
const redirectUrl = email const redirectUrl = email
? `/auth/login?redirect=/invite?token=${tokenParam}&email=${email}` ? `/auth/login?redirect=/invite?token=${tokenParam}&email=${encodeURIComponent(email)}`
: `/auth/login?redirect=/invite?token=${tokenParam}`; : `/auth/login?redirect=/invite?token=${tokenParam}`;
router.push(redirectUrl); router.push(redirectUrl);
} }
@@ -121,7 +117,7 @@ export default function InviteStatusCard({
async function goToSignup() { async function goToSignup() {
await api.post("/auth/logout", {}); await api.post("/auth/logout", {});
const redirectUrl = email const redirectUrl = email
? `/auth/signup?redirect=/invite?token=${tokenParam}&email=${email}` ? `/auth/signup?redirect=/invite?token=${tokenParam}&email=${encodeURIComponent(email)}`
: `/auth/signup?redirect=/invite?token=${tokenParam}`; : `/auth/signup?redirect=/invite?token=${tokenParam}`;
router.push(redirectUrl); router.push(redirectUrl);
} }
@@ -161,9 +157,7 @@ export default function InviteStatusCard({
Cannot Accept Invite Cannot Accept Invite
</p> </p>
<p className="text-center text-sm"> <p className="text-center text-sm">
This organization has reached its user limit. Please This organization has reached its user limit. Please contact the organization administrator to upgrade their plan before accepting this invite.
contact the organization administrator to upgrade their
plan before accepting this invite.
</p> </p>
</div> </div>
); );

View File

@@ -333,8 +333,7 @@ export default function PendingSitesTable({
"jupiter", "jupiter",
"saturn", "saturn",
"uranus", "uranus",
"neptune", "neptune"
"pluto"
].includes(originalRow.exitNodeName.toLowerCase()); ].includes(originalRow.exitNodeName.toLowerCase());
if (isCloudNode) { if (isCloudNode) {

View File

@@ -342,8 +342,7 @@ export default function SitesTable({
"jupiter", "jupiter",
"saturn", "saturn",
"uranus", "uranus",
"neptune", "neptune"
"pluto"
].includes(originalRow.exitNodeName.toLowerCase()); ].includes(originalRow.exitNodeName.toLowerCase());
if (isCloudNode) { if (isCloudNode) {

View File

@@ -164,7 +164,7 @@ const countryClass = cn(
const highlightedCountryClass = cn( const highlightedCountryClass = cn(
sharedCountryClass, sharedCountryClass,
"stroke-[3]", "stroke-2",
"fill-[#f4f4f5]", "fill-[#f4f4f5]",
"stroke-[#f36117]", "stroke-[#f36117]",
"dark:fill-[#3f3f46]" "dark:fill-[#3f3f46]"
@@ -194,20 +194,11 @@ function drawInteractiveCountries(
const path = setupProjetionPath(); const path = setupProjetionPath();
const data = parseWorldTopoJsonToGeoJsonFeatures(); const data = parseWorldTopoJsonToGeoJsonFeatures();
const svg = d3.select(element); const svg = d3.select(element);
const countriesLayer = svg.append("g");
const hoverLayer = svg.append("g").style("pointer-events", "none");
const hoverPath = hoverLayer
.append("path")
.datum(null)
.attr("class", highlightedCountryClass)
.style("display", "none");
countriesLayer svg.selectAll("path")
.selectAll("path")
.data(data) .data(data)
.enter() .enter()
.append("path") .append("path")
.attr("data-country-path", "true")
.attr("class", countryClass) .attr("class", countryClass)
.attr("d", path as never) .attr("d", path as never)
@@ -218,10 +209,9 @@ function drawInteractiveCountries(
y, y,
hoveredCountryAlpha3Code: country.properties.a3 hoveredCountryAlpha3Code: country.properties.a3
}); });
hoverPath // brings country to front
.datum(country) this.parentNode?.appendChild(this);
.attr("d", path(country) as string) d3.select(this).attr("class", highlightedCountryClass);
.style("display", null);
}) })
.on("mousemove", function (event) { .on("mousemove", function (event) {
@@ -231,7 +221,7 @@ function drawInteractiveCountries(
.on("mouseout", function () { .on("mouseout", function () {
setTooltip({ x: 0, y: 0, hoveredCountryAlpha3Code: null }); setTooltip({ x: 0, y: 0, hoveredCountryAlpha3Code: null });
hoverPath.style("display", "none"); d3.select(this).attr("class", countryClass);
}); });
return svg; return svg;
@@ -267,7 +257,7 @@ function colorInCountriesWithValues(
const svg = d3.select(element); const svg = d3.select(element);
return svg return svg
.selectAll('path[data-country-path="true"]') .selectAll("path")
.style("fill", (countryPath) => { .style("fill", (countryPath) => {
const country = getCountryByCountryPath(countryPath); const country = getCountryByCountryPath(countryPath);
if (!country?.count) { if (!country?.count) {