Merge branch 'dev' into feat/login-page-customization

Fred KISSIE
2025-12-05 22:38:07 +01:00
275 changed files with 21920 additions and 6990 deletions

View File

@@ -86,6 +86,7 @@ export enum ActionsEnum {
updateOrgDomain = "updateOrgDomain",
getDNSRecords = "getDNSRecords",
createNewt = "createNewt",
createOlm = "createOlm",
createIdp = "createIdp",
updateIdp = "updateIdp",
deleteIdp = "deleteIdp",

View File

@@ -36,13 +36,15 @@ export async function createSession(
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token))
);
const session: Session = {
sessionId: sessionId,
userId,
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
issuedAt: new Date().getTime()
};
await db.insert(sessions).values(session);
const [session] = await db
.insert(sessions)
.values({
sessionId: sessionId,
userId,
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
issuedAt: new Date().getTime()
})
.returning();
return session;
}
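For reference, the pattern this hunk adopts is Drizzle's insert-returning: instead of constructing the Session object locally and inserting it, the row the database actually stored is returned. A minimal sketch of the shape, with names taken from this diff:

const [session] = await db
    .insert(sessions)
    .values({
        sessionId,
        userId,
        expiresAt: Date.now() + SESSION_COOKIE_EXPIRES,
        issuedAt: Date.now()
    })
    .returning(); // row as stored, including column defaults such as deviceAuthUsed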

View File

@@ -1,9 +1,43 @@
import { Request } from "express";
import { validateSessionToken, SESSION_COOKIE_NAME } from "@server/auth/sessions/app";
import {
validateSessionToken,
SESSION_COOKIE_NAME
} from "@server/auth/sessions/app";
export async function verifySession(req: Request) {
export async function verifySession(req: Request, forceLogin?: boolean) {
const res = await validateSessionToken(
req.cookies[SESSION_COOKIE_NAME] ?? "",
req.cookies[SESSION_COOKIE_NAME] ?? ""
);
if (!forceLogin) {
return res;
}
if (!res.session || !res.user) {
return {
session: null,
user: null
};
}
if (res.session.deviceAuthUsed) {
return {
session: null,
user: null
};
}
if (!res.session.issuedAt) {
return {
session: null,
user: null
};
}
const mins = 5 * 60 * 1000;
const now = new Date().getTime();
if (now - res.session.issuedAt > mins) {
return {
session: null,
user: null
};
}
return res;
}
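With forceLogin set, a session is only accepted when it was issued interactively (not via device auth) within the last five minutes; note the `mins` constant holds that window in milliseconds. A hypothetical route guard built on it (handler name is illustrative, not part of this commit):

import { Request, Response, NextFunction } from "express";
import { verifySession } from "@server/auth/sessions/verifySession";

export async function requireFreshLogin(
    req: Request,
    res: Response,
    next: NextFunction
) {
    // Pass forceLogin=true so stale or device-auth sessions are rejected
    const { session, user } = await verifySession(req, true);
    if (!session || !user) {
        res.status(401).json({ message: "Please sign in again" });
        return;
    }
    next();
}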

View File

@@ -42,11 +42,17 @@ export async function getUniqueResourceName(orgId: string): Promise<string> {
}
const name = generateName();
const count = await db
.select({ niceId: resources.niceId, orgId: resources.orgId })
.from(resources)
.where(and(eq(resources.niceId, name), eq(resources.orgId, orgId)));
if (count.length === 0) {
const [resourceCount, siteResourceCount] = await Promise.all([
db
.select({ niceId: resources.niceId, orgId: resources.orgId })
.from(resources)
.where(and(eq(resources.niceId, name), eq(resources.orgId, orgId))),
db
.select({ niceId: siteResources.niceId, orgId: siteResources.orgId })
.from(siteResources)
.where(and(eq(siteResources.niceId, name), eq(siteResources.orgId, orgId)))
]);
if (resourceCount.length === 0 && siteResourceCount.length === 0) {
return name;
}
loops++;
@@ -61,11 +67,17 @@ export async function getUniqueSiteResourceName(orgId: string): Promise<string>
}
const name = generateName();
const count = await db
.select({ niceId: siteResources.niceId, orgId: siteResources.orgId })
.from(siteResources)
.where(and(eq(siteResources.niceId, name), eq(siteResources.orgId, orgId)));
if (count.length === 0) {
const [resourceCount, siteResourceCount] = await Promise.all([
db
.select({ niceId: resources.niceId, orgId: resources.orgId })
.from(resources)
.where(and(eq(resources.niceId, name), eq(resources.orgId, orgId))),
db
.select({ niceId: siteResources.niceId, orgId: siteResources.orgId })
.from(siteResources)
.where(and(eq(siteResources.niceId, name), eq(siteResources.orgId, orgId)))
]);
if (resourceCount.length === 0 && siteResourceCount.length === 0) {
return name;
}
loops++;

View File

@@ -73,7 +73,7 @@ function createDb() {
return withReplicas(
DrizzlePostgres(primaryPool, {
logger: process.env.NODE_ENV === "development"
logger: process.env.QUERY_LOGGING === "true"
}),
replicas as any
);

View File

@@ -12,6 +12,7 @@ import {
} from "drizzle-orm/pg-core";
import { InferSelectModel } from "drizzle-orm";
import { randomUUID } from "crypto";
import { alias } from "yargs";
export const domains = pgTable("domains", {
domainId: varchar("domainId").primaryKey(),
@@ -41,6 +42,7 @@ export const orgs = pgTable("orgs", {
orgId: varchar("orgId").primaryKey(),
name: varchar("name").notNull(),
subnet: varchar("subnet"),
utilitySubnet: varchar("utilitySubnet"), // this is the subnet for utility addresses
createdAt: text("createdAt"),
requireTwoFactor: boolean("requireTwoFactor"),
maxSessionLengthHours: integer("maxSessionLengthHours"),
@@ -89,8 +91,7 @@ export const sites = pgTable("sites", {
publicKey: varchar("publicKey"),
lastHolePunch: bigint("lastHolePunch", { mode: "number" }),
listenPort: integer("listenPort"),
dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true),
remoteSubnets: text("remoteSubnets") // comma-separated list of subnets that this site can access
dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true)
});
export const resources = pgTable("resources", {
@@ -206,11 +207,41 @@ export const siteResources = pgTable("siteResources", {
.references(() => orgs.orgId, { onDelete: "cascade" }),
niceId: varchar("niceId").notNull(),
name: varchar("name").notNull(),
protocol: varchar("protocol").notNull(),
proxyPort: integer("proxyPort").notNull(),
destinationPort: integer("destinationPort").notNull(),
destinationIp: varchar("destinationIp").notNull(),
enabled: boolean("enabled").notNull().default(true)
mode: varchar("mode").notNull(), // "host" | "cidr" | "port"
protocol: varchar("protocol"), // only for port mode
proxyPort: integer("proxyPort"), // only for port mode
destinationPort: integer("destinationPort"), // only for port mode
destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode
enabled: boolean("enabled").notNull().default(true),
alias: varchar("alias"),
aliasAddress: varchar("aliasAddress")
});
export const clientSiteResources = pgTable("clientSiteResources", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId")
.notNull()
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
});
export const roleSiteResources = pgTable("roleSiteResources", {
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId")
.notNull()
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
});
export const userSiteResources = pgTable("userSiteResources", {
userId: varchar("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId")
.notNull()
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
});
export const users = pgTable("user", {
@@ -258,7 +289,8 @@ export const sessions = pgTable("session", {
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
issuedAt: bigint("issuedAt", { mode: "number" })
issuedAt: bigint("issuedAt", { mode: "number" }),
deviceAuthUsed: boolean("deviceAuthUsed").notNull().default(false)
});
export const newtSessions = pgTable("newtSession", {
@@ -600,7 +632,7 @@ export const idpOrg = pgTable("idpOrg", {
});
export const clients = pgTable("clients", {
clientId: serial("id").primaryKey(),
clientId: serial("clientId").primaryKey(),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@@ -609,6 +641,11 @@ export const clients = pgTable("clients", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
olmId: text("olmId"), // to lock it to a specific olm optionally
name: varchar("name").notNull(),
pubKey: varchar("pubKey"),
subnet: varchar("subnet").notNull(),
@@ -623,23 +660,40 @@ export const clients = pgTable("clients", {
maxConnections: integer("maxConnections")
});
export const clientSites = pgTable("clientSites", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
isRelayed: boolean("isRelayed").notNull().default(false),
endpoint: varchar("endpoint")
});
export const clientSitesAssociationsCache = pgTable(
"clientSitesAssociationsCache",
{
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
.notNull(),
siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false),
endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}
);
export const clientSiteResourcesAssociationsCache = pgTable(
"clientSiteResourcesAssociationsCache",
{
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
.notNull(),
siteResourceId: integer("siteResourceId").notNull()
}
);
export const olms = pgTable("olms", {
olmId: varchar("id").primaryKey(),
secretHash: varchar("secretHash").notNull(),
dateCreated: varchar("dateCreated").notNull(),
version: text("version"),
agent: text("agent"),
name: varchar("name"),
clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
});
@@ -755,6 +809,21 @@ export const requestAuditLog = pgTable(
]
);
export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", {
codeId: serial("codeId").primaryKey(),
code: text("code").notNull().unique(),
ip: text("ip"),
city: text("city"),
deviceName: text("deviceName"),
applicationName: text("applicationName").notNull(),
expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
createdAt: bigint("createdAt", { mode: "number" }).notNull(),
verified: boolean("verified").notNull().default(false),
userId: varchar("userId").references(() => users.userId, {
onDelete: "cascade"
})
});
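A sketch of how a row in this new table might be minted; the helper and its 15-minute TTL are assumptions for illustration, not part of this commit:

import { randomBytes } from "crypto";
import { db, deviceWebAuthCodes } from "@server/db";

// Hypothetical helper: create a device web auth code that a browser
// session can later mark as verified.
export async function createDeviceWebAuthCode(
    applicationName: string,
    userId?: string
) {
    const now = Date.now();
    const [row] = await db
        .insert(deviceWebAuthCodes)
        .values({
            code: randomBytes(16).toString("hex"),
            applicationName,
            userId,
            createdAt: now,
            expiresAt: now + 15 * 60 * 1000 // assumed TTL
        })
        .returning();
    return row;
}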
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -795,7 +864,7 @@ export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type Client = InferSelectModel<typeof clients>;
export type ClientSite = InferSelectModel<typeof clientSites>;
export type ClientSite = InferSelectModel<typeof clientSitesAssociationsCache>;
export type Olm = InferSelectModel<typeof olms>;
export type OlmSession = InferSelectModel<typeof olmSessions>;
export type UserClient = InferSelectModel<typeof userClients>;
@@ -810,4 +879,5 @@ export type Blueprint = InferSelectModel<typeof blueprints>;
export type LicenseKey = InferSelectModel<typeof licenseKey>;
export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;

View File

@@ -1,13 +1,12 @@
import {
sqliteTable,
integer,
text,
real,
index
} from "drizzle-orm/sqlite-core";
import { InferSelectModel } from "drizzle-orm";
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";
import { metadata } from "@app/app/[orgId]/settings/layout";
import {
index,
integer,
real,
sqliteTable,
text
} from "drizzle-orm/sqlite-core";
import { domains, exitNodes, orgs, sessions, users } from "./schema";
export const certificates = sqliteTable("certificates", {
certId: integer("certId").primaryKey({ autoIncrement: true }),

View File

@@ -7,7 +7,7 @@ import {
index,
uniqueIndex
} from "drizzle-orm/sqlite-core";
import { boolean } from "yargs";
import { no } from "zod/v4/locales";
export const domains = sqliteTable("domains", {
domainId: text("domainId").primaryKey(),
@@ -39,6 +39,7 @@ export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(),
name: text("name").notNull(),
subnet: text("subnet"),
utilitySubnet: text("utilitySubnet"), // this is the subnet for utility addresses
createdAt: text("createdAt"),
requireTwoFactor: integer("requireTwoFactor", { mode: "boolean" }),
maxSessionLengthHours: integer("maxSessionLengthHours"), // hours
@@ -100,8 +101,7 @@ export const sites = sqliteTable("sites", {
listenPort: integer("listenPort"),
dockerSocketEnabled: integer("dockerSocketEnabled", { mode: "boolean" })
.notNull()
.default(true),
remoteSubnets: text("remoteSubnets") // comma-separated list of subnets that this site can access
.default(true)
});
export const resources = sqliteTable("resources", {
@@ -202,7 +202,7 @@ export const targetHealthCheck = sqliteTable("targetHealthCheck", {
hcMethod: text("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName"),
hcTlsServerName: text("hcTlsServerName")
});
export const exitNodes = sqliteTable("exitNodes", {
@@ -233,11 +233,41 @@ export const siteResources = sqliteTable("siteResources", {
.references(() => orgs.orgId, { onDelete: "cascade" }),
niceId: text("niceId").notNull(),
name: text("name").notNull(),
protocol: text("protocol").notNull(),
proxyPort: integer("proxyPort").notNull(),
destinationPort: integer("destinationPort").notNull(),
destinationIp: text("destinationIp").notNull(),
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true)
mode: text("mode").notNull(), // "host" | "cidr" | "port"
protocol: text("protocol"), // only for port mode
proxyPort: integer("proxyPort"), // only for port mode
destinationPort: integer("destinationPort"), // only for port mode
destination: text("destination").notNull(), // ip, cidr, hostname
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true),
alias: text("alias"),
aliasAddress: text("aliasAddress")
});
export const clientSiteResources = sqliteTable("clientSiteResources", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId")
.notNull()
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
});
export const roleSiteResources = sqliteTable("roleSiteResources", {
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId")
.notNull()
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
});
export const userSiteResources = sqliteTable("userSiteResources", {
userId: text("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId")
.notNull()
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
});
export const users = sqliteTable("user", {
@@ -313,7 +343,7 @@ export const newts = sqliteTable("newt", {
});
export const clients = sqliteTable("clients", {
clientId: integer("id").primaryKey({ autoIncrement: true }),
clientId: integer("clientId").primaryKey({ autoIncrement: true }),
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
@@ -322,8 +352,14 @@ export const clients = sqliteTable("clients", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
}),
name: text("name").notNull(),
pubKey: text("pubKey"),
olmId: text("olmId"), // to lock it to a specific olm optionally
subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
@@ -335,25 +371,42 @@ export const clients = sqliteTable("clients", {
lastHolePunch: integer("lastHolePunch")
});
export const clientSites = sqliteTable("clientSites", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint")
});
export const clientSitesAssociationsCache = sqliteTable(
"clientSitesAssociationsCache",
{
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
.notNull(),
siteId: integer("siteId").notNull(),
isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}
);
export const clientSiteResourcesAssociationsCache = sqliteTable(
"clientSiteResourcesAssociationsCache",
{
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
.notNull(),
siteResourceId: integer("siteResourceId").notNull()
}
);
export const olms = sqliteTable("olms", {
olmId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(),
version: text("version"),
agent: text("agent"),
name: text("name"),
clientId: integer("clientId").references(() => clients.clientId, {
// we will switch this depending on the current org it wants to connect to
onDelete: "set null"
}),
userId: text("userId").references(() => users.userId, {
// optionally tied to a user and in this case delete when the user deletes
onDelete: "cascade"
})
});
@@ -372,7 +425,10 @@ export const sessions = sqliteTable("session", {
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull(),
issuedAt: integer("issuedAt")
issuedAt: integer("issuedAt"),
deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" })
.notNull()
.default(false)
});
export const newtSessions = sqliteTable("newtSession", {
@@ -809,6 +865,21 @@ export const requestAuditLog = sqliteTable(
]
);
export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", {
codeId: integer("codeId").primaryKey({ autoIncrement: true }),
code: text("code").notNull().unique(),
ip: text("ip"),
city: text("city"),
deviceName: text("deviceName"),
applicationName: text("applicationName").notNull(),
expiresAt: integer("expiresAt").notNull(),
createdAt: integer("createdAt").notNull(),
verified: integer("verified", { mode: "boolean" }).notNull().default(false),
userId: text("userId").references(() => users.userId, {
onDelete: "cascade"
})
});
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -847,7 +918,7 @@ export type ResourceRule = InferSelectModel<typeof resourceRules>;
export type Domain = InferSelectModel<typeof domains>;
export type DnsRecord = InferSelectModel<typeof dnsRecords>;
export type Client = InferSelectModel<typeof clients>;
export type ClientSite = InferSelectModel<typeof clientSites>;
export type ClientSite = InferSelectModel<typeof clientSitesAssociationsCache>;
export type RoleClient = InferSelectModel<typeof roleClients>;
export type UserClient = InferSelectModel<typeof userClients>;
export type SupporterKey = InferSelectModel<typeof supporterKey>;
@@ -866,3 +937,4 @@ export type LicenseKey = InferSelectModel<typeof licenseKey>;
export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;

View File

@@ -11,6 +11,7 @@ import {
ApiKeyOrg,
RemoteExitNode,
Session,
SiteResource,
User,
UserOrg
} from "@server/db";
@@ -77,6 +78,8 @@ declare global {
userOrgId?: string;
userOrgIds?: string[];
remoteExitNode?: RemoteExitNode;
siteResource?: SiteResource;
orgPolicyAllowed?: boolean;
}
}
}

View File

@@ -122,19 +122,17 @@ export async function applyBlueprint({
)
.limit(1);
if (site) {
logger.debug(
`Updating client resource ${result.resource.siteResourceId} on site ${site.sites.siteId}`
);
logger.debug(
`Updating client resource ${result.resource.siteResourceId} on site ${site.sites.siteId}`
);
await addClientTargets(
site.newt.newtId,
result.resource.destinationIp,
result.resource.destinationPort,
result.resource.protocol,
result.resource.proxyPort
);
}
// await addClientTargets(
// site.newt.newtId,
// result.resource.destination,
// result.resource.destinationPort,
// result.resource.protocol,
// result.resource.proxyPort
// );
}
blueprintSucceeded = true;

View File

@@ -75,8 +75,9 @@ export async function updateClientResources(
.set({
name: resourceData.name || resourceNiceId,
siteId: site.siteId,
mode: "port",
proxyPort: resourceData["proxy-port"]!,
destinationIp: resourceData.hostname,
destination: resourceData.hostname,
destinationPort: resourceData["internal-port"],
protocol: resourceData.protocol
})
@@ -98,8 +99,9 @@ export async function updateClientResources(
siteId: site.siteId,
niceId: resourceNiceId,
name: resourceData.name || resourceNiceId,
mode: "port",
proxyPort: resourceData["proxy-port"]!,
destinationIp: resourceData.hostname,
destination: resourceData.hostname,
destinationPort: resourceData["internal-port"],
protocol: resourceData.protocol
})

View File

@@ -221,6 +221,7 @@ export async function updateProxyResources(
domainId: domain ? domain.domainId : null,
enabled: resourceEnabled,
sso: resourceData.auth?.["sso-enabled"] || false,
skipToIdpId: resourceData.auth?.["auto-login-idp"] || null,
ssl: resourceSsl,
setHostHeader: resourceData["host-header"] || null,
tlsServerName: resourceData["tls-server-name"] || null,
@@ -610,6 +611,7 @@ export async function updateProxyResources(
domainId: domain ? domain.domainId : null,
enabled: resourceEnabled,
sso: resourceData.auth?.["sso-enabled"] || false,
skipToIdpId: resourceData.auth?.["auto-login-idp"] || null,
setHostHeader: resourceData["host-header"] || null,
tlsServerName: resourceData["tls-server-name"] || null,
ssl: resourceSsl,
@@ -789,10 +791,6 @@ async function syncRoleResources(
.where(eq(roleResources.resourceId, resourceId));
for (const roleName of ssoRoles) {
if (roleName === "Admin") {
continue; // never add admin access
}
const [role] = await trx
.select()
.from(roles)
@@ -803,6 +801,10 @@ async function syncRoleResources(
throw new Error(`Role not found: ${roleName} in org ${orgId}`);
}
if (role.isAdmin) {
continue; // never add admin access
}
const existingRoleResource = existingRoleResources.find(
(rr) => rr.roleId === role.roleId
);

View File

@@ -59,6 +59,7 @@ export const AuthSchema = z.object({
}),
"sso-users": z.array(z.email()).optional().default([]),
"whitelist-users": z.array(z.email()).optional().default([]),
"auto-login-idp": z.int().positive().optional(),
});
export const RuleSchema = z.object({

View File

@@ -0,0 +1,286 @@
import {
clients,
db,
olms,
orgs,
roleClients,
roles,
userClients,
userOrgs,
Transaction
} from "@server/db";
import { eq, and, notInArray } from "drizzle-orm";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import logger from "@server/logger";
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
import { sendTerminateClient } from "@server/routers/client/terminate";
export async function calculateUserClientsForOrgs(
userId: string,
trx?: Transaction
): Promise<void> {
const execute = async (transaction: Transaction) => {
// Get all OLMs for this user
const userOlms = await transaction
.select()
.from(olms)
.where(eq(olms.userId, userId));
if (userOlms.length === 0) {
// No OLMs for this user, but we should still clean up any orphaned clients
await cleanupOrphanedClients(userId, transaction);
return;
}
// Get all user orgs
const allUserOrgs = await transaction
.select()
.from(userOrgs)
.where(eq(userOrgs.userId, userId));
const userOrgIds = allUserOrgs.map((uo) => uo.orgId);
// For each OLM, ensure there's a client in each org the user is in
for (const olm of userOlms) {
for (const userOrg of allUserOrgs) {
const orgId = userOrg.orgId;
const [org] = await transaction
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
if (!org) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): org not found`
);
continue;
}
if (!org.subnet) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): org has no subnet configured`
);
continue;
}
// Get admin role for this org (needed for access grants)
const [adminRole] = await transaction
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no admin role found`
);
continue;
}
// Check if a client already exists for this OLM+user+org combination
const [existingClient] = await transaction
.select()
.from(clients)
.where(
and(
eq(clients.userId, userId),
eq(clients.orgId, orgId),
eq(clients.olmId, olm.olmId)
)
)
.limit(1);
if (existingClient) {
// Ensure admin role has access to the client
const [existingRoleClient] = await transaction
.select()
.from(roleClients)
.where(
and(
eq(roleClients.roleId, adminRole.roleId),
eq(
roleClients.clientId,
existingClient.clientId
)
)
)
.limit(1);
if (!existingRoleClient) {
await transaction.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: existingClient.clientId
});
logger.debug(
`Granted admin role access to existing client ${existingClient.clientId} for OLM ${olm.olmId} in org ${orgId} (user ${userId})`
);
}
// Ensure user has access to the client
const [existingUserClient] = await transaction
.select()
.from(userClients)
.where(
and(
eq(userClients.userId, userId),
eq(
userClients.clientId,
existingClient.clientId
)
)
)
.limit(1);
if (!existingUserClient) {
await transaction.insert(userClients).values({
userId,
clientId: existingClient.clientId
});
logger.debug(
`Granted user access to existing client ${existingClient.clientId} for OLM ${olm.olmId} in org ${orgId} (user ${userId})`
);
}
logger.debug(
`Client already exists for OLM ${olm.olmId} in org ${orgId} (user ${userId}), skipping creation`
);
continue;
}
// Get exit nodes for this org
const exitNodesList = await listExitNodes(orgId);
if (exitNodesList.length === 0) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no exit nodes found`
);
continue;
}
const randomExitNode =
exitNodesList[
Math.floor(Math.random() * exitNodesList.length)
];
// Get next available subnet
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
logger.warn(
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no available subnet found`
);
continue;
}
const subnet = newSubnet.split("/")[0];
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`;
// Create the client
const [newClient] = await transaction
.insert(clients)
.values({
userId,
orgId: userOrg.orgId,
exitNodeId: randomExitNode.exitNodeId,
name: olm.name || "User Client",
subnet: updatedSubnet,
olmId: olm.olmId,
type: "olm"
})
.returning();
await rebuildClientAssociationsFromClient(
newClient,
transaction
);
// Grant admin role access to the client
await transaction.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: newClient.clientId
});
// Grant user access to the client
await transaction.insert(userClients).values({
userId,
clientId: newClient.clientId
});
logger.debug(
`Created client for OLM ${olm.olmId} in org ${orgId} (user ${userId}) with access granted to admin role and user`
);
}
}
// Clean up clients in orgs the user is no longer in
await cleanupOrphanedClients(userId, transaction, userOrgIds);
};
if (trx) {
// Use provided transaction
await execute(trx);
} else {
// Create new transaction
await db.transaction(async (transaction) => {
await execute(transaction);
});
}
}
async function cleanupOrphanedClients(
userId: string,
trx: Transaction,
userOrgIds: string[] = []
): Promise<void> {
// Find all OLM clients for this user that should be deleted
// If userOrgIds is empty, delete all OLM clients (user has no orgs)
// If userOrgIds has values, delete clients in orgs they're not in
const clientsToDelete = await trx
.select({ clientId: clients.clientId })
.from(clients)
.where(
userOrgIds.length > 0
? and(
eq(clients.userId, userId),
notInArray(clients.orgId, userOrgIds)
)
: and(eq(clients.userId, userId))
);
if (clientsToDelete.length > 0) {
const deletedClients = await trx
.delete(clients)
.where(
userOrgIds.length > 0
? and(
eq(clients.userId, userId),
notInArray(clients.orgId, userOrgIds)
)
: and(eq(clients.userId, userId))
)
.returning();
// Rebuild associations for each deleted client to clean up related data
for (const deletedClient of deletedClients) {
await rebuildClientAssociationsFromClient(deletedClient, trx);
if (deletedClient.olmId) {
await sendTerminateClient(
deletedClient.clientId,
deletedClient.olmId
);
}
}
if (userOrgIds.length === 0) {
logger.debug(
`Deleted all ${clientsToDelete.length} OLM client(s) for user ${userId} (user has no orgs)`
);
} else {
logger.debug(
`Deleted ${clientsToDelete.length} orphaned OLM client(s) for user ${userId} in orgs they're no longer in`
);
}
}
}
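A hypothetical call site (the import path is assumed; the commit only shows the function itself): recompute a user's clients inside the same transaction that changed their org memberships, so client creation and orphan cleanup stay atomic with the membership change.

import { db } from "@server/db";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";

await db.transaction(async (trx) => {
    // ...insert or delete userOrgs rows for `userId` here...
    await calculateUserClientsForOrgs(userId, trx);
});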

View File

@@ -85,10 +85,6 @@ export class Config {
? "true"
: "false";
process.env.FLAGS_ENABLE_CLIENTS = parsedConfig.flags?.enable_clients
? "true"
: "false";
process.env.PRODUCT_UPDATES_NOTIFICATION_ENABLED = parsedConfig.app
.notifications.product_updates
? "true"

View File

@@ -2,7 +2,7 @@ import path from "path";
import { fileURLToPath } from "url";
// This is a placeholder value replaced by the build process
export const APP_VERSION = "1.12.1";
export const APP_VERSION = "1.12.3";
export const __FILENAME = fileURLToPath(import.meta.url);
export const __DIRNAME = path.dirname(__FILENAME);

View File

@@ -18,6 +18,7 @@ import { defaultRoleAllowedActions } from "@server/routers/role";
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import config from "@server/lib/config";
export async function createUserAccountOrg(
userId: string,
@@ -76,6 +77,8 @@ export async function createUserAccountOrg(
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
const newOrg = await trx
.insert(orgs)
.values({
@@ -83,6 +86,7 @@ export async function createUserAccountOrg(
name,
// subnet
subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
utilitySubnet: utilitySubnet,
createdAt: new Date().toISOString()
})
.returning();

View File

@@ -1,7 +1,15 @@
import { db } from "@server/db";
import {
clientSitesAssociationsCache,
db,
SiteResource,
siteResources,
Transaction
} from "@server/db";
import { clients, orgs, sites } from "@server/db";
import { and, eq, isNotNull } from "drizzle-orm";
import config from "@server/lib/config";
import z from "zod";
import logger from "@server/logger";
interface IPRange {
start: bigint;
@@ -279,6 +287,56 @@ export async function getNextAvailableClientSubnet(
return subnet;
}
export async function getNextAvailableAliasAddress(
orgId: string
): Promise<string> {
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
throw new Error(`Organization with ID ${orgId} not found`);
}
if (!org.subnet) {
throw new Error(`Organization with ID ${orgId} has no subnet defined`);
}
if (!org.utilitySubnet) {
throw new Error(
`Organization with ID ${orgId} has no utility subnet defined`
);
}
const existingAddresses = await db
.select({
aliasAddress: siteResources.aliasAddress
})
.from(siteResources)
.where(
and(
isNotNull(siteResources.aliasAddress),
eq(siteResources.orgId, orgId)
)
);
const addresses = [
...existingAddresses.map(
(site) => `${site.aliasAddress?.split("/")[0]}/32`
),
// reserve a /29 for the dns server and other stuff
`${org.utilitySubnet.split("/")[0]}/29`
].filter((address) => address !== null) as string[];
let subnet = findNextAvailableCidr(addresses, 32, org.utilitySubnet);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
// remove the cidr
subnet = subnet.split("/")[0];
return subnet;
}
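An illustrative walk-through, assuming the default utility subnet 100.96.128.0/24: the first /29 (100.96.128.0 through 100.96.128.7) is reserved for the DNS server and other utilities, so with no aliases allocated yet the first address handed out would likely be 100.96.128.8. Exact behavior depends on findNextAvailableCidr, which is not shown in this diff.

const address = await getNextAvailableAliasAddress("my-org");
// e.g. "100.96.128.8" — a bare IP; the /32 suffix is stripped before returning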
export async function getNextAvailableOrgSubnet(): Promise<string> {
const existingAddresses = await db
.select({
@@ -300,3 +358,113 @@ export async function getNextAvailableOrgSubnet(): Promise<string> {
return subnet;
}
export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[] {
const remoteSubnets = allSiteResources
.filter((sr) => {
if (sr.mode === "cidr") return true;
if (sr.mode === "host") {
// check if its a valid IP using zod
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
const parseResult = ipSchema.safeParse(sr.destination);
return parseResult.success;
}
return false;
})
.map((sr) => {
if (sr.mode === "cidr") return sr.destination;
if (sr.mode === "host") {
return `${sr.destination}/32`;
}
return ""; // This should never be reached due to filtering, but satisfies TypeScript
})
.filter((subnet) => subnet !== ""); // Remove empty strings just to be safe
// remove duplicates
return Array.from(new Set(remoteSubnets));
}
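A worked example of the mapping (illustrative data; the cast only satisfies the compiler, since SiteResource carries more fields): host-mode IPs become /32s, cidr mode passes through, hostnames are dropped, and duplicates are removed.

const subnets = generateRemoteSubnets([
    { mode: "host", destination: "10.0.0.5" },
    { mode: "cidr", destination: "192.168.1.0/24" },
    { mode: "host", destination: "intranet.local" } // not an IP, filtered out
] as SiteResource[]);
// -> ["10.0.0.5/32", "192.168.1.0/24"]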
export type Alias = { alias: string | null; aliasAddress: string | null };
export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] {
let aliasConfigs = allSiteResources
.filter((sr) => sr.alias && sr.aliasAddress && sr.mode == "host")
.map((sr) => ({
alias: sr.alias,
aliasAddress: sr.aliasAddress
}));
return aliasConfigs;
}
export type SubnetProxyTarget = {
sourcePrefix: string; // must be a cidr
destPrefix: string; // must be a cidr
rewriteTo?: string; // must be a cidr
portRange?: {
min: number;
max: number;
}[];
};
export function generateSubnetProxyTargets(
siteResource: SiteResource,
clients: {
clientId: number;
pubKey: string | null;
subnet: string | null;
}[]
): SubnetProxyTarget[] {
const targets: SubnetProxyTarget[] = [];
if (clients.length === 0) {
logger.debug(
`No clients have access to site resource ${siteResource.siteResourceId}, skipping target generation.`
);
return [];
}
for (const clientSite of clients) {
if (!clientSite.subnet) {
logger.debug(
`Client ${clientSite.clientId} has no subnet, skipping for site resource ${siteResource.siteResourceId}.`
);
continue;
}
const clientPrefix = `${clientSite.subnet.split("/")[0]}/32`;
if (siteResource.mode == "host") {
let destination = siteResource.destination;
// check if this is a valid ip
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
if (ipSchema.safeParse(destination).success) {
destination = `${destination}/32`;
targets.push({
sourcePrefix: clientPrefix,
destPrefix: destination
});
}
if (siteResource.alias && siteResource.aliasAddress) {
// also push a match for the alias address
targets.push({
sourcePrefix: clientPrefix,
destPrefix: `${siteResource.aliasAddress}/32`,
rewriteTo: destination
});
}
} else if (siteResource.mode == "cidr") {
targets.push({
sourcePrefix: clientPrefix,
destPrefix: siteResource.destination
});
}
}
// print a nice representation of the targets
// logger.debug(
// `Generated subnet proxy targets for: ${JSON.stringify(targets, null, 2)}`
// );
return targets;
}
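A worked example for a host-mode resource with an alias (illustrative values): each client contributes a /32 source prefix, and the alias address produces a second target that rewrites to the real destination.

const targets = generateSubnetProxyTargets(
    {
        siteResourceId: 1,
        mode: "host",
        destination: "10.1.2.3",
        alias: "db.internal",
        aliasAddress: "100.96.128.9"
    } as SiteResource,
    [{ clientId: 7, pubKey: "wg-pubkey", subnet: "100.90.128.5/24" }]
);
// -> [
//   { sourcePrefix: "100.90.128.5/32", destPrefix: "10.1.2.3/32" },
//   { sourcePrefix: "100.90.128.5/32", destPrefix: "100.96.128.9/32",
//     rewriteTo: "10.1.2.3/32" }
// ]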

server/lib/lock.ts (new file, 111 lines)
View File

@@ -0,0 +1,111 @@
export class LockManager {
/**
* Acquire a distributed lock using Redis SET with NX and PX options
* @param lockKey - Unique identifier for the lock
* @param ttlMs - Time to live in milliseconds
* @returns Promise<boolean> - true if lock acquired, false otherwise
*/
async acquireLock(
lockKey: string,
ttlMs: number = 30000
): Promise<boolean> {
return true;
}
/**
* Release a lock using Lua script to ensure atomicity
* @param lockKey - Unique identifier for the lock
*/
async releaseLock(lockKey: string): Promise<void> {}
/**
* Force release a lock regardless of owner (use with caution)
* @param lockKey - Unique identifier for the lock
*/
async forceReleaseLock(lockKey: string): Promise<void> {}
/**
* Check if a lock exists and get its info
* @param lockKey - Unique identifier for the lock
* @returns Promise<{exists: boolean, ownedByMe: boolean, ttl: number}>
*/
async getLockInfo(lockKey: string): Promise<{
exists: boolean;
ownedByMe: boolean;
ttl: number;
owner?: string;
}> {
return { exists: true, ownedByMe: true, ttl: 0 };
}
/**
* Extend the TTL of an existing lock owned by this worker
* @param lockKey - Unique identifier for the lock
* @param ttlMs - New TTL in milliseconds
* @returns Promise<boolean> - true if extended successfully
*/
async extendLock(lockKey: string, ttlMs: number): Promise<boolean> {
return true;
}
/**
* Attempt to acquire lock with retries and exponential backoff
* @param lockKey - Unique identifier for the lock
* @param ttlMs - Time to live in milliseconds
* @param maxRetries - Maximum number of retry attempts
* @param baseDelayMs - Base delay between retries in milliseconds
* @returns Promise<boolean> - true if lock acquired
*/
async acquireLockWithRetry(
lockKey: string,
ttlMs: number = 30000,
maxRetries: number = 5,
baseDelayMs: number = 100
): Promise<boolean> {
return true;
}
/**
* Execute a function while holding a lock
* @param lockKey - Unique identifier for the lock
* @param fn - Function to execute while holding the lock
* @param ttlMs - Lock TTL in milliseconds
* @returns Promise<T> - Result of the executed function
*/
async withLock<T>(
lockKey: string,
fn: () => Promise<T>,
ttlMs: number = 30000
): Promise<T> {
const acquired = await this.acquireLock(lockKey, ttlMs);
if (!acquired) {
throw new Error(`Failed to acquire lock: ${lockKey}`);
}
try {
return await fn();
} finally {
await this.releaseLock(lockKey);
}
}
/**
* Clean up expired locks - Redis handles this automatically, but this method
* can be used to get statistics about locks
* @returns Promise<{activeLocksCount: number, locksOwnedByMe: number}>
*/
async getLockStatistics(): Promise<{
activeLocksCount: number;
locksOwnedByMe: number;
}> {
return { activeLocksCount: 0, locksOwnedByMe: 0 };
}
/**
* Close the Redis connection
*/
async disconnect(): Promise<void> {}
}
export const lockManager = new LockManager();
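Usage sketch of the withLock helper. Note that in this build the acquire/release methods are stubs (acquireLock always succeeds), so the lock is effectively a no-op until a Redis-backed implementation fills them in; the key name and guarded function below are illustrative.

import { lockManager } from "@server/lib/lock";

const result = await lockManager.withLock(
    `org:${orgId}:rebuild`,        // lock key (illustrative)
    async () => doRebuild(orgId),  // critical section, runs while the lock is held
    30000                          // TTL in milliseconds
);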

View File

@@ -229,6 +229,11 @@ export const configSchema = z
.default(51820)
.transform(stoi)
.pipe(portSchema),
clients_start_port: portSchema
.optional()
.default(21820)
.transform(stoi)
.pipe(portSchema),
base_endpoint: z
.string()
.optional()
@@ -249,12 +254,14 @@ export const configSchema = z
orgs: z
.object({
block_size: z.number().positive().gt(0).optional().default(24),
subnet_group: z.string().optional().default("100.90.128.0/24")
subnet_group: z.string().optional().default("100.90.128.0/24"),
utility_subnet_group: z.string().optional().default("100.96.128.0/24") //just hardcode this for now as well
})
.optional()
.default({
block_size: 24,
subnet_group: "100.90.128.0/24"
subnet_group: "100.90.128.0/24",
utility_subnet_group: "100.96.128.0/24"
}),
rate_limits: z
.object({
@@ -318,8 +325,7 @@ export const configSchema = z
enable_integration_api: z.boolean().optional(),
disable_local_sites: z.boolean().optional(),
disable_basic_wireguard_sites: z.boolean().optional(),
disable_config_managed_domains: z.boolean().optional(),
enable_clients: z.boolean().optional().default(true)
disable_config_managed_domains: z.boolean().optional()
})
.optional(),
dns: z

File diff suppressed because it is too large.

View File

@@ -4,7 +4,7 @@ import { getHostMeta } from "./hostMeta";
import logger from "@server/logger";
import { apiKeys, db, roles } from "@server/db";
import { sites, users, orgs, resources, clients, idp } from "@server/db";
import { eq, count, notInArray } from "drizzle-orm";
import { eq, count, notInArray, and } from "drizzle-orm";
import { APP_VERSION } from "./consts";
import crypto from "crypto";
import { UserType } from "@server/types/UserTypes";
@@ -113,7 +113,12 @@ class TelemetryClient {
const [customRoles] = await db
.select({ count: count() })
.from(roles)
.where(notInArray(roles.name, ["Admin", "Member"]));
.where(
and(
eq(roles.isAdmin, false),
notInArray(roles.name, ["Member"])
)
);
const adminUsers = await db
.select({ email: users.email })

View File

@@ -345,9 +345,9 @@ export async function getTraefikConfig(
routerMiddlewares.push(rewriteMiddlewareName);
}
logger.debug(
`Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
);
// logger.debug(
// `Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
// );
} catch (error) {
logger.error(
`Failed to create path rewrite middleware for resource ${resource.resourceId}: ${error}`

View File

@@ -11,6 +11,7 @@ export * from "./verifyRoleAccess";
export * from "./verifyUserAccess";
export * from "./verifyAdmin";
export * from "./verifySetResourceUsers";
export * from "./verifySetResourceClients";
export * from "./verifyUserInRole";
export * from "./verifyAccessTokenAccess";
export * from "./requestTimeout";
@@ -24,7 +25,7 @@ export * from "./integration";
export * from "./verifyUserHasAction";
export * from "./verifyApiKeyAccess";
export * from "./verifyDomainAccess";
export * from "./verifyClientsEnabled";
export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess";
export * from "./logActionAudit";
export * from "./logActionAudit";
export * from "./verifyOlmAccess";

View File

@@ -7,6 +7,7 @@ export * from "./verifyApiKeyTargetAccess";
export * from "./verifyApiKeyRoleAccess";
export * from "./verifyApiKeyUserAccess";
export * from "./verifyApiKeySetResourceUsers";
export * from "./verifyApiKeySetResourceClients";
export * from "./verifyAccessTokenAccess";
export * from "./verifyApiKeyIsRoot";
export * from "./verifyApiKeyApiKeyAccess";

View File

@@ -0,0 +1,73 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { clients } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyApiKeySetResourceClients(
req: Request,
res: Response,
next: NextFunction
) {
const apiKey = req.apiKey;
const singleClientId = req.params.clientId || req.body.clientId || req.query.clientId;
const { clientIds } = req.body;
const allClientIds = clientIds || (singleClientId ? [parseInt(singleClientId as string)] : []);
if (!apiKey) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Key not authenticated")
);
}
if (apiKey.isRoot) {
// Root keys can access any client in any org
return next();
}
if (!req.apiKeyOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to this organization"
)
);
}
if (allClientIds.length === 0) {
return next();
}
try {
const orgId = req.apiKeyOrg.orgId;
const clientsData = await db
.select()
.from(clients)
.where(
and(
inArray(clients.clientId, allClientIds),
eq(clients.orgId, orgId)
)
);
if (clientsData.length !== allClientIds.length) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to one or more specified clients"
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking if key has access to the specified clients"
)
);
}
}

View File

@@ -11,7 +11,9 @@ export async function verifyApiKeySetResourceUsers(
next: NextFunction
) {
const apiKey = req.apiKey;
const userIds = req.body.userIds;
const singleUserId = req.params.userId || req.body.userId || req.query.userId;
const { userIds } = req.body;
const allUserIds = userIds || (singleUserId ? [singleUserId] : []);
if (!apiKey) {
return next(
@@ -33,11 +35,7 @@ export async function verifyApiKeySetResourceUsers(
);
}
if (!userIds) {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid user IDs"));
}
if (userIds.length === 0) {
if (allUserIds.length === 0) {
return next();
}
@@ -48,12 +46,12 @@ export async function verifyApiKeySetResourceUsers(
.from(userOrgs)
.where(
and(
inArray(userOrgs.userId, userIds),
inArray(userOrgs.userId, allUserIds),
eq(userOrgs.orgId, orgId)
)
);
if (userOrgsData.length !== userIds.length) {
if (userOrgsData.length !== allUserIds.length) {
return next(
createHttpError(
HttpCode.FORBIDDEN,

View File

@@ -13,8 +13,6 @@ export async function verifyApiKeySiteResourceAccess(
try {
const apiKey = req.apiKey;
const siteResourceId = parseInt(req.params.siteResourceId);
const siteId = parseInt(req.params.siteId);
const orgId = req.params.orgId;
if (!apiKey) {
return next(
@@ -22,11 +20,11 @@ export async function verifyApiKeySiteResourceAccess(
);
}
if (!siteResourceId || !siteId || !orgId) {
if (!siteResourceId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Missing required parameters"
"Missing siteResourceId parameter"
)
);
}
@@ -41,9 +39,7 @@ export async function verifyApiKeySiteResourceAccess(
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
eq(siteResources.siteResourceId, siteResourceId)
))
.limit(1);
@@ -64,11 +60,11 @@ export async function verifyApiKeySiteResourceAccess(
.where(
and(
eq(apiKeyOrg.apiKeyId, apiKey.apiKeyId),
eq(apiKeyOrg.orgId, orgId)
eq(apiKeyOrg.orgId, siteResource.orgId)
)
)
.limit(1);
if (apiKeyOrgRes.length === 0) {
return next(
createHttpError(
@@ -77,12 +73,11 @@ export async function verifyApiKeySiteResourceAccess(
)
);
}
req.apiKeyOrg = apiKeyOrgRes[0];
}
// Attach the siteResource to the request for use in the next middleware/route
// @ts-ignore - Extending Request type
req.siteResource = siteResource;
return next();

View File

@@ -5,6 +5,7 @@ import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { canUserAccessResource } from "@server/auth/canUserAccessResource";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyAccessTokenAccess(
req: Request,
@@ -96,6 +97,24 @@ export async function verifyAccessTokenAccess(
req.userOrgId = resource[0].orgId!;
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const resourceAllowed = await canUserAccessResource({
userId,
resourceId,

View File

@@ -4,6 +4,7 @@ import { roles, userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyAdmin(
req: Request,
@@ -43,6 +44,24 @@ export async function verifyAdmin(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userRole = await db
.select()
.from(roles)

View File

@@ -4,6 +4,7 @@ import { userOrgs, apiKeys, apiKeyOrg } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyApiKeyAccess(
req: Request,
@@ -84,6 +85,24 @@ export async function verifyApiKeyAccess(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;

View File

@@ -4,6 +4,7 @@ import { userOrgs, clients, roleClients, userClients } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyClientAccess(
req: Request,
@@ -75,6 +76,24 @@ export async function verifyClientAccess(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgId = client.orgId;

View File

@@ -1,29 +0,0 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import config from "@server/lib/config";
export async function verifyClientsEnabled(
req: Request,
res: Response,
next: NextFunction
) {
try {
if (!config.getRawConfig().flags?.enable_clients) {
return next(
createHttpError(
HttpCode.NOT_IMPLEMENTED,
"Clients are not enabled on this server."
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to check if clients are enabled"
)
);
}
}

View File

@@ -4,6 +4,7 @@ import { userOrgs, apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyDomainAccess(
req: Request,
@@ -78,6 +79,24 @@ export async function verifyDomainAccess(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;

View File

@@ -0,0 +1,45 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { db, olms } from "@server/db";
import { and, eq } from "drizzle-orm";
export async function verifyOlmAccess(
req: Request,
res: Response,
next: NextFunction
) {
try {
const userId = req.user!.userId;
const olmId = req.params.olmId || req.body.olmId || req.query.olmId;
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
const [existingOlm] = await db
.select()
.from(olms)
.where(and(eq(olms.olmId, olmId), eq(olms.userId, userId)));
if (!existingOlm) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this olm"
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking if user has access to this user"
)
);
}
}

View File

@@ -47,22 +47,22 @@ export async function verifyOrgAccess(
);
}
const policyCheck = await checkOrgAccessPolicy({
orgId,
userId,
session: req.session
});
logger.debug("Org check policy result", { policyCheck });
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
if (req.orgPolicyAllowed === undefined) {
const policyCheck = await checkOrgAccessPolicy({
orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
// User has access, attach the user's role to the request for potential future use
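This hunk introduces the request-scoped memoization (req.orgPolicyAllowed) that the middlewares above also adopt, so the org access policy runs at most once per request. The repeated block could be factored into a shared helper; a hypothetical sketch, not part of this commit:

import { Request } from "express";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";

// Hypothetical shared helper: evaluate the policy once, cache the verdict
// on the request, and throw the same FORBIDDEN error the middlewares return.
export async function ensureOrgPolicy(req: Request, orgId: string, userId: string) {
    if (req.orgPolicyAllowed !== undefined) return;
    const policyCheck = await checkOrgAccessPolicy({
        orgId,
        userId,
        session: req.session
    });
    req.orgPolicyAllowed = policyCheck.allowed;
    if (!policyCheck.allowed || policyCheck.error) {
        throw createHttpError(
            HttpCode.FORBIDDEN,
            "Failed organization access policy check: " +
                (policyCheck.error || "Unknown error")
        );
    }
}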

View File

@@ -1,14 +1,10 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import {
resources,
userOrgs,
userResources,
roleResources,
} from "@server/db";
import { resources, userOrgs, userResources, roleResources } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyResourceAccess(
req: Request,
@@ -73,6 +69,24 @@ export async function verifyResourceAccess(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgId = resource[0].orgId;

View File

@@ -5,6 +5,7 @@ import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyRoleAccess(
req: Request,
@@ -105,6 +106,33 @@ export async function verifyRoleAccess(
req.userOrgRoleId = userOrg[0].roleId;
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
return next();
} catch (error) {
logger.error("Error verifying role access:", error);
@@ -116,4 +144,3 @@ export async function verifyRoleAccess(
);
}
}

View File

@@ -1,10 +1,5 @@
import { NextFunction, Response } from "express";
import ErrorResponse from "@server/types/ErrorResponse";
import { db } from "@server/db";
import { users } from "@server/db";
import { eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { verifySession } from "@server/auth/sessions/verifySession";
import { unauthorized } from "@server/auth/unauthorizedResponse";
@@ -13,24 +8,15 @@ export const verifySessionMiddleware = async (
res: Response<ErrorResponse>,
next: NextFunction
) => {
const { session, user } = await verifySession(req);
const { forceLogin } = req.query;
const { session, user } = await verifySession(req, forceLogin === "true");
if (!session || !user) {
return next(unauthorized());
}
const existingUser = await db
.select()
.from(users)
.where(eq(users.userId, user.userId));
if (!existingUser || !existingUser[0]) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "User does not exist")
);
}
req.user = existingUser[0];
req.user = user;
req.session = session;
next();
return next();
};
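The middleware now delegates the user lookup to verifySession and honors a forceLogin query flag. A client-side sketch (URL is illustrative):

// Request a fresh-login-gated endpoint; a 401 here means the session is
// older than five minutes or was issued via device auth.
const res = await fetch("/api/v1/user?forceLogin=true", {
    credentials: "include"
});
if (res.status === 401) {
    // redirect the user to log in again
}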

View File

@@ -0,0 +1,90 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { clients } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifySetResourceClients(
req: Request,
res: Response,
next: NextFunction
) {
const userId = req.user!.userId;
const singleClientId =
req.params.clientId || req.body.clientId || req.query.clientId;
const { clientIds } = req.body;
const allClientIds =
clientIds ||
(singleClientId ? [parseInt(singleClientId as string)] : []);
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
if (allClientIds.length === 0) {
return next();
}
try {
const orgId = req.userOrg.orgId;
// get all clients for the clientIds
const clientsData = await db
.select()
.from(clients)
.where(
and(
inArray(clients.clientId, allClientIds),
eq(clients.orgId, orgId)
)
);
if (clientsData.length !== allClientIds.length) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to one or more specified clients"
)
);
}
return next();
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking if user has access to the specified clients"
)
);
}
}
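// The org access policy block above is copy-pasted into several of these
// middlewares. As a sketch only (not part of this changeset), the repeated
// logic could be factored into a helper; everything below assumes the same
// req.userOrg / req.orgPolicyAllowed / checkOrgAccessPolicy shapes used above.
async function ensureOrgPolicyAllowed(
    req: Request,
    userId: string
): Promise<string | null> {
    // Skip if already evaluated for this request, or if there is no org context
    if (req.orgPolicyAllowed !== undefined || !req.userOrg?.orgId) {
        return null;
    }
    const policyCheck = await checkOrgAccessPolicy({
        orgId: req.userOrg.orgId,
        userId,
        session: req.session
    });
    req.orgPolicyAllowed = policyCheck.allowed;
    if (!policyCheck.allowed || policyCheck.error) {
        return policyCheck.error || "Unknown error";
    }
    return null;
}
// Usage in any of the middlewares above:
// const policyError = await ensureOrgPolicyAllowed(req, userId);
// if (policyError) {
//     return next(
//         createHttpError(
//             HttpCode.FORBIDDEN,
//             "Failed organization access policy check: " + policyError
//         )
//     );
// }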

View File

@@ -4,6 +4,7 @@ import { userOrgs } from "@server/db";
import { and, eq, inArray, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifySetResourceUsers(
req: Request,
@@ -28,6 +29,24 @@ export async function verifySetResourceUsers(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
if (!userIds) {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid user IDs"));
}

View File

@@ -1,16 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import {
sites,
userOrgs,
userSites,
roleSites,
roles,
} from "@server/db";
import { sites, userOrgs, userSites, roleSites, roles } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifySiteAccess(
req: Request,
@@ -82,6 +77,24 @@ export async function verifySiteAccess(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgId = site[0].orgId;

View File

@@ -1,10 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { db, roleSiteResources, userOrgs, userSiteResources } from "@server/db";
import { siteResources } from "@server/db";
import { eq, and } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifySiteResourceAccess(
req: Request,
@@ -12,44 +13,145 @@ export async function verifySiteResourceAccess(
next: NextFunction
): Promise<any> {
try {
const siteResourceId = parseInt(req.params.siteResourceId);
const siteId = parseInt(req.params.siteId);
const orgId = req.params.orgId;
const userId = req.user!.userId;
const siteResourceId =
req.params.siteResourceId ||
req.body.siteResourceId ||
req.query.siteResourceId;
if (!siteResourceId || !siteId || !orgId) {
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
if (!siteResourceId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Missing required parameters"
"Site resource ID is required"
)
);
}
const siteResourceIdNum = parseInt(siteResourceId as string, 10);
if (isNaN(siteResourceIdNum)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid site resource ID"
)
);
}
// Check if the site resource exists and belongs to the specified site and org
const [siteResource] = await db
.select()
.from(siteResources)
.where(and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
))
.where(eq(siteResources.siteResourceId, siteResourceIdNum))
.limit(1);
if (!siteResource) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Site resource not found"
`Site resource with ID ${siteResourceIdNum} not found`
)
);
}
if (!siteResource.orgId) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Site resource with ID ${siteResourceIdNum} does not have an organization ID`
)
);
}
if (!req.userOrg) {
const userOrgRole = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, siteResource.orgId)
)
)
.limit(1);
req.userOrg = userOrgRole[0];
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgId = siteResource.orgId;
// Attach the siteResource to the request for use in the next middleware/route
// @ts-ignore - Extending Request type
req.siteResource = siteResource;
next();
const roleResourceAccess = await db
.select()
.from(roleSiteResources)
.where(
and(
eq(roleSiteResources.siteResourceId, siteResourceIdNum),
eq(roleSiteResources.roleId, userOrgRoleId)
)
)
.limit(1);
if (roleResourceAccess.length > 0) {
return next();
}
const userResourceAccess = await db
.select()
.from(userSiteResources)
.where(
and(
eq(userSiteResources.userId, userId),
eq(userSiteResources.siteResourceId, siteResourceIdNum)
)
)
.limit(1);
if (userResourceAccess.length > 0) {
return next();
}
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this resource"
)
);
} catch (error) {
logger.error("Error verifying site resource access:", error);
return next(

View File

@@ -5,6 +5,7 @@ import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { canUserAccessResource } from "../auth/canUserAccessResource";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyTargetAccess(
req: Request,
@@ -102,6 +103,26 @@ export async function verifyTargetAccess(
req.userOrgId = resource[0].orgId!;
}
const orgId = req.userOrg.orgId;
if (req.orgPolicyAllowed === undefined && orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
const resourceAllowed = await canUserAccessResource({
userId,
resourceId,

View File

@@ -15,7 +15,9 @@ export const verifySessionUserMiddleware = async (
res: Response<ErrorResponse>,
next: NextFunction
) => {
const { session, user } = await verifySession(req);
const { forceLogin } = req.query;
const { session, user } = await verifySession(req, forceLogin === "true");
if (!session || !user) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(`User session not found. IP: ${req.ip}.`);

View File

@@ -4,6 +4,7 @@ import { userOrgs } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyUserAccess(
req: Request,
@@ -47,6 +48,24 @@ export async function verifyUserAccess(
);
}
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
const policyCheck = await checkOrgAccessPolicy({
orgId: req.userOrg.orgId,
userId,
session: req.session
});
req.orgPolicyAllowed = policyCheck.allowed;
if (!policyCheck.allowed || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
}
return next();
} catch (error) {
return next(

View File

@@ -45,6 +45,11 @@ export class PrivateConfig {
this.rawPrivateConfig = parsedPrivateConfig;
process.env.BRANDING_HIDE_AUTH_LAYOUT_FOOTER =
this.rawPrivateConfig.branding?.hide_auth_layout_footer === true
? "true"
: "false";
if (this.rawPrivateConfig.branding?.colors) {
process.env.BRANDING_COLORS = JSON.stringify(
this.rawPrivateConfig.branding?.colors

View File

@@ -197,7 +197,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
// // set the item in the database if it is offline
// if (isActuallyOnline != node.online) {
// await db
// await trx
// .update(exitNodes)
// .set({ online: isActuallyOnline })
// .where(eq(exitNodes.exitNodeId, node.exitNodeId));

server/private/lib/lock.ts
View File

@@ -0,0 +1,363 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { config } from "@server/lib/config";
import logger from "@server/logger";
import { redis } from "#private/lib/redis";
export class LockManager {
/**
* Acquire a distributed lock using Redis SET with NX and PX options
* @param lockKey - Unique identifier for the lock
* @param ttlMs - Time to live in milliseconds
* @returns Promise<boolean> - true if lock acquired, false otherwise
*/
async acquireLock(
lockKey: string,
ttlMs: number = 30000
): Promise<boolean> {
if (!redis || !redis.status || redis.status !== "ready") {
return true;
}
const lockValue = `${
config.getRawConfig().gerbil.exit_node_name
}:${Date.now()}`;
const redisKey = `lock:${lockKey}`;
try {
// Use SET with NX (only set if not exists) and PX (expire in milliseconds)
// This is atomic and handles both setting and expiration
const result = await redis.set(
redisKey,
lockValue,
"PX",
ttlMs,
"NX"
);
if (result === "OK") {
logger.debug(
`Lock acquired: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
return true;
}
// Check if the existing lock is from this worker (reentrant behavior)
const existingValue = await redis.get(redisKey);
if (
existingValue &&
existingValue.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
)
) {
// Extend the lock TTL since it's the same worker
await redis.pexpire(redisKey, ttlMs);
logger.debug(
`Lock extended: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
return true;
}
return false;
} catch (error) {
logger.error(`Failed to acquire lock ${lockKey}:`, error);
return false;
}
}
/**
* Release a lock using Lua script to ensure atomicity
* @param lockKey - Unique identifier for the lock
*/
async releaseLock(lockKey: string): Promise<void> {
if (!redis || !redis.status || redis.status !== "ready") {
return;
}
const redisKey = `lock:${lockKey}`;
// Lua script to ensure we only delete the lock if it belongs to this worker
const luaScript = `
local key = KEYS[1]
local worker_prefix = ARGV[1]
local current_value = redis.call('GET', key)
if current_value and string.find(current_value, worker_prefix, 1, true) == 1 then
return redis.call('DEL', key)
else
return 0
end
`;
try {
const result = (await redis.eval(
luaScript,
1,
redisKey,
`${config.getRawConfig().gerbil.exit_node_name}:`
)) as number;
if (result === 1) {
logger.debug(
`Lock released: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
} else {
logger.warn(
`Lock not released - not owned by worker: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
}`
);
}
} catch (error) {
logger.error(`Failed to release lock ${lockKey}:`, error);
}
}
/**
* Force release a lock regardless of owner (use with caution)
* @param lockKey - Unique identifier for the lock
*/
async forceReleaseLock(lockKey: string): Promise<void> {
if (!redis || !redis.status || redis.status !== "ready") {
return;
}
const redisKey = `lock:${lockKey}`;
try {
const result = await redis.del(redisKey);
if (result === 1) {
logger.debug(`Lock force released: ${lockKey}`);
}
} catch (error) {
logger.error(`Failed to force release lock ${lockKey}:`, error);
}
}
/**
* Check if a lock exists and get its info
* @param lockKey - Unique identifier for the lock
* @returns Promise<{exists: boolean, ownedByMe: boolean, ttl: number}>
*/
async getLockInfo(lockKey: string): Promise<{
exists: boolean;
ownedByMe: boolean;
ttl: number;
owner?: string;
}> {
if (!redis || !redis.status || redis.status !== "ready") {
return { exists: false, ownedByMe: true, ttl: 0 };
}
const redisKey = `lock:${lockKey}`;
try {
const [value, ttl] = await Promise.all([
redis.get(redisKey),
redis.pttl(redisKey)
]);
const exists = value !== null;
const ownedByMe =
exists &&
value!.startsWith(`${config.getRawConfig().gerbil.exit_node_name}:`);
const owner = exists ? value!.split(":")[0] : undefined;
return {
exists,
ownedByMe,
ttl: ttl > 0 ? ttl : 0,
owner
};
} catch (error) {
logger.error(`Failed to get lock info ${lockKey}:`, error);
return { exists: false, ownedByMe: false, ttl: 0 };
}
}
/**
* Extend the TTL of an existing lock owned by this worker
* @param lockKey - Unique identifier for the lock
* @param ttlMs - New TTL in milliseconds
* @returns Promise<boolean> - true if extended successfully
*/
async extendLock(lockKey: string, ttlMs: number): Promise<boolean> {
if (!redis || !redis.status || redis.status !== "ready") {
return true;
}
const redisKey = `lock:${lockKey}`;
// Lua script to extend TTL only if lock is owned by this worker
const luaScript = `
local key = KEYS[1]
local worker_prefix = ARGV[1]
local ttl = tonumber(ARGV[2])
local current_value = redis.call('GET', key)
if current_value and string.find(current_value, worker_prefix, 1, true) == 1 then
return redis.call('PEXPIRE', key, ttl)
else
return 0
end
`;
try {
const result = (await redis.eval(
luaScript,
1,
redisKey,
`${config.getRawConfig().gerbil.exit_node_name}:`,
ttlMs.toString()
)) as number;
if (result === 1) {
logger.debug(
`Lock extended: ${lockKey} by ${
config.getRawConfig().gerbil.exit_node_name
} for ${ttlMs}ms`
);
return true;
}
return false;
} catch (error) {
logger.error(`Failed to extend lock ${lockKey}:`, error);
return false;
}
}
/**
* Attempt to acquire lock with retries and exponential backoff
* @param lockKey - Unique identifier for the lock
* @param ttlMs - Time to live in milliseconds
* @param maxRetries - Maximum number of retry attempts
* @param baseDelayMs - Base delay between retries in milliseconds
* @returns Promise<boolean> - true if lock acquired
*/
async acquireLockWithRetry(
lockKey: string,
ttlMs: number = 30000,
maxRetries: number = 5,
baseDelayMs: number = 100
): Promise<boolean> {
if (!redis || !redis.status || redis.status !== "ready") {
return true;
}
for (let attempt = 0; attempt <= maxRetries; attempt++) {
const acquired = await this.acquireLock(lockKey, ttlMs);
if (acquired) {
return true;
}
if (attempt < maxRetries) {
// Exponential backoff with jitter
const delay =
baseDelayMs * Math.pow(2, attempt) + Math.random() * 100;
await new Promise((resolve) => setTimeout(resolve, delay));
}
}
logger.warn(
`Failed to acquire lock ${lockKey} after ${maxRetries + 1} attempts`
);
return false;
}
/**
* Execute a function while holding a lock
* @param lockKey - Unique identifier for the lock
* @param fn - Function to execute while holding the lock
* @param ttlMs - Lock TTL in milliseconds
* @returns Promise<T> - Result of the executed function
*/
async withLock<T>(
lockKey: string,
fn: () => Promise<T>,
ttlMs: number = 30000
): Promise<T> {
if (!redis || !redis.status || redis.status !== "ready") {
return await fn();
}
const acquired = await this.acquireLock(lockKey, ttlMs);
if (!acquired) {
throw new Error(`Failed to acquire lock: ${lockKey}`);
}
try {
return await fn();
} finally {
await this.releaseLock(lockKey);
}
}
/**
* Clean up expired locks - Redis handles this automatically, but this method
* can be used to get statistics about locks
* @returns Promise<{activeLocksCount: number, locksOwnedByMe: number}>
*/
async getLockStatistics(): Promise<{
activeLocksCount: number;
locksOwnedByMe: number;
}> {
if (!redis || !redis.status || redis.status !== "ready") {
return { activeLocksCount: 0, locksOwnedByMe: 0 };
}
try {
const keys = await redis.keys("lock:*");
let locksOwnedByMe = 0;
if (keys.length > 0) {
const values = await redis.mget(...keys);
locksOwnedByMe = values.filter(
(value) =>
value &&
value.startsWith(
`${config.getRawConfig().gerbil.exit_node_name}:`
)
).length;
}
return {
activeLocksCount: keys.length,
locksOwnedByMe
};
} catch (error) {
logger.error("Failed to get lock statistics:", error);
return { activeLocksCount: 0, locksOwnedByMe: 0 };
}
}
/**
* Close the Redis connection
*/
async disconnect(): Promise<void> {
if (!redis || !redis.status || redis.status !== "ready") {
return;
}
await redis.quit();
}
}
export const lockManager = new LockManager();
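// A minimal usage sketch for the LockManager above (not part of this
// changeset; "refresh-exit-nodes" is a made-up lock key). Note that every
// method treats an unavailable Redis as a no-op success, so single-node
// deployments fall straight through to the callback.
async function refreshExitNodesExample(): Promise<void> {
    // withLock acquires, runs the callback, and always releases in a finally
    // block; it throws if another worker currently holds the lock.
    await lockManager.withLock(
        "refresh-exit-nodes",
        async () => {
            // ...critical section: at most one worker executes this at a time...
        },
        15000 // ttlMs: the lock self-expires if this worker dies mid-section
    );
}
// For contended locks, acquireLockWithRetry backs off exponentially with
// jitter (100ms, 200ms, 400ms, ...) before giving up:
//   const ok = await lockManager.acquireLockWithRetry("refresh-exit-nodes", 15000, 5);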

View File

@@ -124,6 +124,7 @@ export const privateConfigSchema = z.object({
})
)
.optional(),
hide_auth_layout_footer: z.boolean().optional().default(false),
login_page: z
.object({
subtitle_text: z.string().optional(),

View File

@@ -434,9 +434,9 @@ export async function getTraefikConfig(
routerMiddlewares.push(rewriteMiddlewareName);
}
logger.debug(
`Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
);
// logger.debug(
// `Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
// );
} catch (error) {
logger.error(
`Failed to create path rewrite middleware for resource ${resource.resourceId}: ${error}`

View File

@@ -16,4 +16,5 @@ export * from "./verifyRemoteExitNodeAccess";
export * from "./verifyIdpAccess";
export * from "./verifyLoginPageAccess";
export * from "./logActionAudit";
export * from "./verifySubscription";
export * from "./verifySubscription";
export * from "./verifyValidLicense";

View File

@@ -31,7 +31,6 @@ import {
verifyUserIsServerAdmin,
verifySiteAccess,
verifyClientAccess,
verifyClientsEnabled,
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import {
@@ -437,7 +436,6 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyClientsEnabled,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret

View File

@@ -1043,7 +1043,7 @@ hybridRouter.get(
);
}
let rules = await db
const rules = await db
.select()
.from(resourceRules)
.where(eq(resourceRules.resourceId, resourceId));
@@ -1369,7 +1369,7 @@ const updateHolePunchSchema = z.object({
port: z.number(),
timestamp: z.number(),
reachableAt: z.string().optional(),
publicKey: z.string().optional()
publicKey: z.string() // this is the client public key
});
hybridRouter.post(
"/gerbil/update-hole-punch",
@@ -1408,7 +1408,7 @@ hybridRouter.post(
);
}
const { olmId, newtId, ip, port, timestamp, token, reachableAt } =
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
parsedParams.data;
const destinations = await updateAndGenerateEndpointDestinations(
@@ -1418,6 +1418,7 @@ hybridRouter.post(
port,
timestamp,
token,
publicKey,
exitNode,
true
);
@@ -1742,7 +1743,12 @@ hybridRouter.post(
tls: logEntry.tls
}));
await db.insert(requestAuditLog).values(logEntries);
// batch them into inserts of 100 to avoid exceeding parameter limits
const batchSize = 100;
for (let i = 0; i < logEntries.length; i += batchSize) {
const batch = logEntries.slice(i, i + batchSize);
await db.insert(requestAuditLog).values(batch);
}
return response(res, {
data: null,

View File

@@ -13,13 +13,17 @@
import * as orgIdp from "#private/routers/orgIdp";
import * as org from "#private/routers/org";
import * as logs from "#private/routers/auditLogs";
import { Router } from "express";
import {
verifyApiKey,
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess,
} from "@server/middlewares";
import {
verifyValidSubscription,
verifyValidLicense
} from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
@@ -42,4 +46,42 @@ authenticated.delete(
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
logActionAudit(ActionsEnum.deleteIdp),
orgIdp.deleteOrgIdp,
);
);
authenticated.get(
"/org/:orgId/logs/action",
verifyValidLicense,
verifyValidSubscription,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logs.queryActionAuditLogs
);
authenticated.get(
"/org/:orgId/logs/action/export",
verifyValidLicense,
verifyValidSubscription,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
logs.exportActionAuditLogs
);
authenticated.get(
"/org/:orgId/logs/access",
verifyValidLicense,
verifyValidSubscription,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logs.queryAccessAuditLogs
);
authenticated.get(
"/org/:orgId/logs/access/export",
verifyValidLicense,
verifyValidSubscription,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
logs.exportAccessAuditLogs
);

View File

@@ -164,7 +164,7 @@ export async function createLoginPage(
.select()
.from(exitNodes)
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.limit(10);
.limit(10);
}
// select a random exit node

View File

@@ -38,6 +38,7 @@ import { rateLimitService } from "#private/lib/rateLimit";
import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app";
// Merge public and private message handlers
Object.assign(messageHandlers, privateMessageHandlers);
@@ -370,6 +371,9 @@ const sendToClientLocal = async (
client.send(messageString);
}
});
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
return true;
};
@@ -478,7 +482,8 @@ const getActiveNodes = async (
// Token verification middleware
const verifyToken = async (
token: string,
clientType: ClientType
clientType: ClientType,
userToken: string
): Promise<TokenPayload | null> => {
try {
if (clientType === "newt") {
@@ -506,6 +511,17 @@ const verifyToken = async (
if (!existingOlm || !existingOlm[0]) {
return null;
}
if (olm.userId) { // this is a user device and we need to check the user token
const { session: userSession, user } = await validateSessionToken(userToken);
if (!userSession || !user) {
return null;
}
if (user.userId !== olm.userId) {
return null;
}
}
return { client: existingOlm[0], session, clientType };
} else if (clientType === "remoteExitNode") {
const { session, remoteExitNode } =
@@ -652,6 +668,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
url.searchParams.get("token") ||
request.headers["sec-websocket-protocol"] ||
"";
const userToken = url.searchParams.get('userToken') || '';
let clientType = url.searchParams.get(
"clientType"
) as ClientType;
@@ -673,7 +690,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
return;
}
const tokenPayload = await verifyToken(token, clientType);
const tokenPayload = await verifyToken(token, clientType, userToken);
if (!tokenPayload) {
logger.debug(
"Unauthorized connection attempt: invalid token..."
@@ -792,6 +809,28 @@ if (redisManager.isRedisEnabled()) {
);
}
// Disconnect a specific client and force them to reconnect
const disconnectClient = async (clientId: string): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
logger.debug(`No connections found for client ID: ${clientId}`);
return false;
}
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
// Close all connections for this client
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.close(1000, "Disconnected by server");
}
});
return true;
};
// Cleanup function for graceful shutdown
const cleanup = async (): Promise<void> => {
try {
@@ -829,6 +868,7 @@ export {
connectedClients,
hasActiveConnections,
getActiveNodes,
disconnectClient,
NODE_ID,
cleanup
};
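// A hypothetical client-side sketch (not in this changeset) of how an olm on
// a user device would connect after this change: the olm token still goes in
// `token`, and the owning user's session token rides in the new `userToken`
// query parameter, which verifyToken checks against olm.userId.
import WebSocket from "ws";

function connectUserDeviceOlm(
    wsBaseUrl: string, // assumed, e.g. "wss://example.com/ws" — the real path is not shown here
    olmToken: string,
    userSessionToken: string
): WebSocket {
    const url = new URL(wsBaseUrl);
    url.searchParams.set("token", olmToken);
    url.searchParams.set("clientType", "olm");
    url.searchParams.set("userToken", userSessionToken);
    return new WebSocket(url);
}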

View File

@@ -131,7 +131,7 @@ export function queryRequest(data: Q) {
eq(requestAuditLog.resourceId, resources.resourceId)
) // TODO: Is this efficient?
.where(getWhere(data))
.orderBy(desc(requestAuditLog.timestamp), desc(requestAuditLog.id));
.orderBy(desc(requestAuditLog.timestamp));
}
export function countRequestQuery(data: Q) {

View File

@@ -13,4 +13,7 @@ export * from "./initialSetupComplete";
export * from "./validateSetupToken";
export * from "./changePassword";
export * from "./checkResourceSession";
export * from "./securityKey";
export * from "./securityKey";
export * from "./startDeviceWebAuth";
export * from "./verifyDeviceWebAuth";
export * from "./pollDeviceWebAuth";

View File

@@ -1,7 +1,9 @@
import {
createSession,
generateSessionToken,
serializeSessionCookie
invalidateSession,
serializeSessionCookie,
SESSION_COOKIE_NAME
} from "@server/auth/sessions/app";
import { db, resources } from "@server/db";
import { users, securityKeys } from "@server/db";
@@ -21,11 +23,11 @@ import { UserType } from "@server/types/UserTypes";
import { logAccessAudit } from "#dynamic/lib/logAccessAudit";
export const loginBodySchema = z.strictObject({
email: z.email().toLowerCase(),
password: z.string(),
code: z.string().optional(),
resourceGuid: z.string().optional()
});
email: z.email().toLowerCase(),
password: z.string(),
code: z.string().optional(),
resourceGuid: z.string().optional()
});
export type LoginBody = z.infer<typeof loginBodySchema>;
@@ -41,6 +43,21 @@ export async function login(
res: Response,
next: NextFunction
): Promise<any> {
const { forceLogin } = req.query;
const { session: existingSession } = await verifySession(
req,
forceLogin === "true"
);
if (existingSession) {
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Already logged in",
status: HttpCode.OK
});
}
const parsedBody = loginBodySchema.safeParse(req.body);
if (!parsedBody.success) {
@@ -55,17 +72,6 @@ export async function login(
const { email, password, code, resourceGuid } = parsedBody.data;
try {
const { session: existingSession } = await verifySession(req);
if (existingSession) {
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Already logged in",
status: HttpCode.OK
});
}
let resourceId: number | null = null;
let orgId: string | null = null;
if (resourceGuid) {
@@ -225,6 +231,12 @@ export async function login(
}
}
// check for previous cookie value and expire it
const previousCookie = req.cookies[SESSION_COOKIE_NAME];
if (previousCookie) {
await invalidateSession(previousCookie);
}
const token = generateSessionToken();
const sess = await createSession(token, existingUser.userId);
const isSecure = req.protocol === "https";

View File

@@ -0,0 +1,168 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { response } from "@server/lib/response";
import { db, deviceWebAuthCodes } from "@server/db";
import { eq, and, gt } from "drizzle-orm";
import {
createSession,
generateSessionToken
} from "@server/auth/sessions/app";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
const paramsSchema = z.object({
code: z.string().min(1, "Code is required")
});
export type PollDeviceWebAuthParams = z.infer<typeof paramsSchema>;
// Helper function to hash device code before querying database
function hashDeviceCode(code: string): string {
return encodeHexLowerCase(
sha256(new TextEncoder().encode(code))
);
}
export type PollDeviceWebAuthResponse = {
verified: boolean;
token?: string;
};
// Helper function to extract IP from request (same as in startDeviceWebAuth)
function extractIpFromRequest(req: Request): string | undefined {
const ip = req.ip || req.socket.remoteAddress;
if (!ip) {
return undefined;
}
// Handle IPv6 format [::1] or IPv4 format
if (ip.startsWith("[") && ip.includes("]")) {
const ipv6Match = ip.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
    // Handle IPv4 with port (split at last colon); plain IPv6 addresses
    // contain multiple colons and must not be truncated here
    const lastColonIndex = ip.lastIndexOf(":");
    if (lastColonIndex !== -1 && ip.indexOf(":") === lastColonIndex) {
        return ip.substring(0, lastColonIndex);
    }
return ip;
}
export async function pollDeviceWebAuth(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
try {
const { code } = parsedParams.data;
const now = Date.now();
const requestIp = extractIpFromRequest(req);
// Hash the code before querying
const hashedCode = hashDeviceCode(code);
// Find the code in the database
const [deviceCode] = await db
.select()
.from(deviceWebAuthCodes)
.where(eq(deviceWebAuthCodes.code, hashedCode))
.limit(1);
if (!deviceCode) {
return response<PollDeviceWebAuthResponse>(res, {
data: {
verified: false
},
success: true,
error: false,
message: "Code not found",
status: HttpCode.OK
});
}
// Check if code is expired
if (deviceCode.expiresAt <= now) {
return response<PollDeviceWebAuthResponse>(res, {
data: {
verified: false
},
success: true,
error: false,
message: "Code expired",
status: HttpCode.OK
});
}
// Check if code is verified
if (!deviceCode.verified) {
return response<PollDeviceWebAuthResponse>(res, {
data: {
verified: false
},
success: true,
error: false,
message: "Code not yet verified",
status: HttpCode.OK
});
}
// Check if userId is set (should be set when verified)
if (!deviceCode.userId) {
logger.error("Device code is verified but userId is missing", { codeId: deviceCode.codeId });
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Invalid code state"
)
);
}
// Generate session token
const token = generateSessionToken();
await createSession(token, deviceCode.userId);
// Delete the code after successful exchange for a token
await db
.delete(deviceWebAuthCodes)
.where(eq(deviceWebAuthCodes.codeId, deviceCode.codeId));
return response<PollDeviceWebAuthResponse>(res, {
data: {
verified: true,
token
},
success: true,
error: false,
message: "Code verified and session created",
status: HttpCode.OK
});
} catch (e) {
logger.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to poll device code"
)
);
}
}
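// A hypothetical device-side sketch of the full flow (the route paths are
// assumptions — the router registrations are not shown in this changeset):
// start a code, display it to the user, then poll until the browser verifies
// it and a session token is returned.
async function deviceWebAuthFlow(apiBase: string): Promise<string> {
    const startRes = await fetch(`${apiBase}/auth/device/web`, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ applicationName: "Olm CLI", deviceName: "laptop" })
    });
    const { data: start } = await startRes.json(); // { code, expiresInSeconds }
    console.log(`Enter this code in your browser: ${start.code}`);

    const deadline = Date.now() + start.expiresInSeconds * 1000;
    while (Date.now() < deadline) {
        const pollRes = await fetch(`${apiBase}/auth/device/web/${start.code}`);
        const { data: poll } = await pollRes.json(); // { verified, token? }
        if (poll.verified && poll.token) {
            return poll.token; // session token; the code row is deleted server-side
        }
        await new Promise((r) => setTimeout(r, 3000)); // poll every 3 seconds
    }
    throw new Error("Device code expired before it was verified");
}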

View File

@@ -52,7 +52,7 @@ setInterval(async () => {
await db
.delete(webauthnChallenge)
.where(lt(webauthnChallenge.expiresAt, now));
logger.debug("Cleaned up expired security key challenges");
// logger.debug("Cleaned up expired security key challenges");
} catch (error) {
logger.error("Failed to clean up expired security key challenges", error);
}

View File

@@ -0,0 +1,156 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { response } from "@server/lib/response";
import { db, deviceWebAuthCodes } from "@server/db";
import { alphabet, generateRandomString } from "oslo/crypto";
import { createDate } from "oslo";
import { TimeSpan } from "oslo";
import { maxmindLookup } from "@server/db/maxmind";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
const bodySchema = z.object({
deviceName: z.string().optional(),
applicationName: z.string().min(1, "Application name is required")
}).strict();
export type StartDeviceWebAuthBody = z.infer<typeof bodySchema>;
export type StartDeviceWebAuthResponse = {
code: string;
expiresInSeconds: number;
};
// Helper function to generate device code in format A1AJ-N5JD
function generateDeviceCode(): string {
const part1 = generateRandomString(4, alphabet("A-Z", "0-9"));
const part2 = generateRandomString(4, alphabet("A-Z", "0-9"));
return `${part1}-${part2}`;
}
// Helper function to hash device code before storing in database
function hashDeviceCode(code: string): string {
return encodeHexLowerCase(
sha256(new TextEncoder().encode(code))
);
}
// Helper function to extract IP from request
function extractIpFromRequest(req: Request): string | undefined {
const ip = req.ip || req.socket.remoteAddress;
if (!ip) {
return undefined;
}
// Handle IPv6 format [::1] or IPv4 format
if (ip.startsWith("[") && ip.includes("]")) {
const ipv6Match = ip.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
    // Handle IPv4 with port (split at last colon); plain IPv6 addresses
    // contain multiple colons and must not be truncated here
    const lastColonIndex = ip.lastIndexOf(":");
    if (lastColonIndex !== -1 && ip.indexOf(":") === lastColonIndex) {
        return ip.substring(0, lastColonIndex);
    }
return ip;
}
// Helper function to get city from IP (if available)
async function getCityFromIp(ip: string): Promise<string | undefined> {
try {
if (!maxmindLookup) {
return undefined;
}
const result = maxmindLookup.get(ip);
if (!result) {
return undefined;
}
        // MaxMind CountryResponse doesn't include city data by default.
        // If it were available it would be in result.city?.names?.en, but
        // since we're using the CountryResponse type we skip city lookup for now.
        return undefined;
} catch (error) {
logger.debug("Failed to get city from IP", error);
return undefined;
}
}
export async function startDeviceWebAuth(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
try {
const { deviceName, applicationName } = parsedBody.data;
// Generate device code
const code = generateDeviceCode();
// Hash the code before storing in database
const hashedCode = hashDeviceCode(code);
// Extract IP from request
const ip = extractIpFromRequest(req);
// Get city (optional, may return undefined)
const city = ip ? await getCityFromIp(ip) : undefined;
// Set expiration to 5 minutes from now
const expiresAt = createDate(new TimeSpan(5, "m")).getTime();
// Insert into database (store hashed code)
await db.insert(deviceWebAuthCodes).values({
code: hashedCode,
ip: ip || null,
city: city || null,
deviceName: deviceName || null,
applicationName,
expiresAt,
createdAt: Date.now()
});
// calculate relative expiration in seconds
const expiresInSeconds = Math.floor((expiresAt - Date.now()) / 1000);
return response<StartDeviceWebAuthResponse>(res, {
data: {
code,
expiresInSeconds
},
success: true,
error: false,
message: "Device web auth code generated",
status: HttpCode.OK
});
} catch (e) {
logger.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to start device web auth"
)
);
}
}

View File

@@ -0,0 +1,180 @@
import { Request, Response, NextFunction } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
import { response } from "@server/lib/response";
import { db, deviceWebAuthCodes, sessions } from "@server/db";
import { eq, and, gt } from "drizzle-orm";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { unauthorized } from "@server/auth/unauthorizedResponse";
const bodySchema = z
.object({
code: z.string().min(1, "Code is required"),
verify: z.boolean().optional().default(false) // If false, just check and return metadata
})
.strict();
// Helper function to hash device code before querying database
function hashDeviceCode(code: string): string {
return encodeHexLowerCase(sha256(new TextEncoder().encode(code)));
}
export type VerifyDeviceWebAuthBody = z.infer<typeof bodySchema>;
export type VerifyDeviceWebAuthResponse = {
success: boolean;
message: string;
metadata?: {
ip: string | null;
city: string | null;
deviceName: string | null;
applicationName: string;
createdAt: number;
};
};
export async function verifyDeviceWebAuth(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const { user, session } = req;
if (!user || !session) {
return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized"));
}
if (session.deviceAuthUsed) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Device web auth code already used for this session"
)
);
}
if (!session.issuedAt) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Session issuedAt timestamp missing"
)
);
}
    // make sure the session is not older than 5 minutes
const now = Date.now();
if (now - session.issuedAt > 5 * 60 * 1000) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Session is too old to verify device web auth code"
)
);
}
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
try {
const { code, verify } = parsedBody.data;
const now = Date.now();
logger.debug("Verifying device web auth code:", { code });
// Hash the code before querying
const hashedCode = hashDeviceCode(code);
// Find the code in the database that is not expired and not already verified
const [deviceCode] = await db
.select()
.from(deviceWebAuthCodes)
.where(
and(
eq(deviceWebAuthCodes.code, hashedCode),
gt(deviceWebAuthCodes.expiresAt, now),
eq(deviceWebAuthCodes.verified, false)
)
)
.limit(1);
logger.debug("Device code lookup result:", deviceCode);
if (!deviceCode) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid, expired, or already verified code"
)
);
}
// If verify is false, just return metadata without verifying
if (!verify) {
return response<VerifyDeviceWebAuthResponse>(res, {
data: {
success: true,
message: "Code is valid",
metadata: {
ip: deviceCode.ip,
city: deviceCode.city,
deviceName: deviceCode.deviceName,
applicationName: deviceCode.applicationName,
createdAt: deviceCode.createdAt
}
},
success: true,
error: false,
message: "Code validation successful",
status: HttpCode.OK
});
}
// Update the code to mark it as verified and store the user who verified it
await db
.update(deviceWebAuthCodes)
.set({
verified: true,
userId: req.user!.userId
})
.where(eq(deviceWebAuthCodes.codeId, deviceCode.codeId));
// Also update the session to mark that device auth was used
await db
.update(sessions)
.set({
deviceAuthUsed: true
})
.where(eq(sessions.sessionId, session.sessionId));
return response<VerifyDeviceWebAuthResponse>(res, {
data: {
success: true,
message: "Device code verified successfully"
},
success: true,
error: false,
message: "Device code verified successfully",
status: HttpCode.OK
});
} catch (e) {
logger.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to verify device code"
)
);
}
}
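// A hypothetical browser-side sketch of the two-phase verify above (endpoint
// path assumed). Phase one (verify: false) only fetches the code's metadata
// so the user can confirm the requesting device; phase two (verify: true)
// marks the code verified and consumes this session's single device-auth use.
async function approveDeviceCode(apiBase: string, code: string): Promise<void> {
    const check = await fetch(`${apiBase}/auth/device/web/verify`, {
        method: "POST",
        credentials: "include", // requires a fresh (< 5 min) session that hasn't done device auth
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ code, verify: false })
    });
    const { data } = await check.json();
    console.log("Approve login for:", data.metadata); // ip, city, deviceName, applicationName

    // ...after the user confirms in the UI...
    await fetch(`${apiBase}/auth/device/web/verify`, {
        method: "POST",
        credentials: "include",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ code, verify: true })
    });
}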

View File

@@ -69,9 +69,9 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
)
);
logger.debug(
`Cleaned up request audit logs older than ${retentionDays} days`
);
// logger.debug(
// `Cleaned up request audit logs older than ${retentionDays} days`
// );
} catch (error) {
logger.error("Error cleaning up old request audit logs:", error);
}

View File

@@ -8,8 +8,6 @@ import {
roleClients,
userClients,
olms,
clientSites,
exitNodes,
orgs,
sites
} from "@server/db";
@@ -21,23 +19,24 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import moment from "moment";
import { hashPassword } from "@server/auth/password";
import { isValidCIDR, isValidIP } from "@server/lib/validators";
import { isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
import { OpenAPITags, registry } from "@server/openApi";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { generateId } from "@server/auth/sessions/app";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
const createClientParamsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const createClientSchema = z.strictObject({
name: z.string().min(1).max(255),
siteIds: z.array(z.int().positive()),
olmId: z.string(),
secret: z.string(),
subnet: z.string(),
type: z.enum(["olm"])
});
name: z.string().min(1).max(255),
olmId: z.string(),
secret: z.string(),
subnet: z.string(),
type: z.enum(["olm"])
});
export type CreateClientBody = z.infer<typeof createClientSchema>;
@@ -46,7 +45,7 @@ export type CreateClientResponse = Client;
registry.registerPath({
method: "put",
path: "/org/{orgId}/client",
description: "Create a new client.",
description: "Create a new client for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: createClientParamsSchema,
@@ -77,7 +76,7 @@ export async function createClient(
);
}
const { name, type, siteIds, olmId, secret, subnet } = parsedBody.data;
const { name, type, olmId, secret, subnet } = parsedBody.data;
const parsedParams = createClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
@@ -172,75 +171,90 @@ export async function createClient(
);
}
// check if the olmId already exists
const [existingOlm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId))
.limit(1);
if (existingOlm) {
return next(
createHttpError(
HttpCode.CONFLICT,
`OLM with ID ${olmId} already exists`
)
);
}
let newClient: Client | null = null;
await db.transaction(async (trx) => {
// TODO: more intelligent way to pick the exit node
const exitNodesList = await listExitNodes(orgId);
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const adminRole = await trx
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
trx.rollback();
if (!adminRole) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
const [newClient] = await trx
[newClient] = await trx
.insert(clients)
.values({
exitNodeId: randomExitNode.exitNodeId,
orgId,
name,
subnet: updatedSubnet,
type
type,
olmId // this is to lock it to a specific olm even if the olm moves across clients
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole[0].roleId,
roleId: adminRole.roleId,
clientId: newClient.clientId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
if (req.user && req.userOrgRoleId != adminRole.roleId) {
// make sure the user can access the client
                await trx.insert(userClients).values({
userId: req.user?.userId!,
userId: req.user.userId,
clientId: newClient.clientId
});
}
// Create site to client associations
if (siteIds && siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map((siteId) => ({
clientId: newClient.clientId,
siteId
}))
);
let secretToUse = secret;
if (!secretToUse) {
secretToUse = generateId(48);
}
const secretHash = await hashPassword(secret);
const secretHash = await hashPassword(secretToUse);
await trx.insert(olms).values({
olmId,
secretHash,
name,
clientId: newClient.clientId,
dateCreated: moment().toISOString()
});
return response<CreateClientResponse>(res, {
data: newClient,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
await rebuildClientAssociationsFromClient(newClient, trx);
});
return response<CreateClientResponse>(res, {
data: newClient,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);

View File

@@ -0,0 +1,253 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import {
roles,
Client,
clients,
roleClients,
userClients,
olms,
orgs,
sites
} from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import { isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
const paramsSchema = z
.object({
orgId: z.string(),
userId: z.string()
})
.strict();
const bodySchema = z
.object({
name: z.string().min(1).max(255),
olmId: z.string(),
subnet: z.string(),
type: z.enum(["olm"])
})
.strict();
export type CreateClientAndOlmBody = z.infer<typeof bodySchema>;
export type CreateClientAndOlmResponse = Client;
registry.registerPath({
method: "put",
path: "/org/{orgId}/user/{userId}/client",
description:
"Create a new client for a user and associate it with an existing olm.",
tags: [OpenAPITags.Client, OpenAPITags.Org, OpenAPITags.User],
request: {
params: paramsSchema,
body: {
content: {
"application/json": {
schema: bodySchema
}
}
}
},
responses: {}
});
export async function createUserClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, type, olmId, subnet } = parsedBody.data;
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId, userId } = parsedParams.data;
if (!isValidIP(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
if (!org.subnet) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Organization with ID ${orgId} has no subnet defined`
)
);
}
if (!isIpInCidr(subnet, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const subnetExistsClients = await db
.select()
.from(clients)
.where(
and(eq(clients.subnet, updatedSubnet), eq(clients.orgId, orgId))
)
.limit(1);
if (subnetExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${updatedSubnet} already exists in clients`
)
);
}
const subnetExistsSites = await db
.select()
.from(sites)
.where(
and(eq(sites.address, updatedSubnet), eq(sites.orgId, orgId))
)
.limit(1);
if (subnetExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${updatedSubnet} already exists in sites`
)
);
}
// check if the olmId already exists
const [existingOlm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId))
.limit(1);
if (!existingOlm) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`OLM with ID ${olmId} does not exist`
)
);
}
if (existingOlm.userId !== userId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`OLM with ID ${olmId} does not belong to user with ID ${userId}`
)
);
}
let newClient: Client | null = null;
await db.transaction(async (trx) => {
// TODO: more intelligent way to pick the exit node
const exitNodesList = await listExitNodes(orgId);
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
[newClient] = await trx
.insert(clients)
.values({
exitNodeId: randomExitNode.exitNodeId,
orgId,
name,
subnet: updatedSubnet,
type,
olmId, // this is to lock it to a specific olm even if the olm moves across clients
userId
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole.roleId,
clientId: newClient.clientId
});
            await trx.insert(userClients).values({
userId,
clientId: newClient.clientId
});
await rebuildClientAssociationsFromClient(newClient, trx);
});
return response<CreateClientAndOlmResponse>(res, {
data: newClient,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { db, olms } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -9,10 +9,12 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { sendTerminateClient } from "./terminate";
const deleteClientSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
clientId: z.string().transform(Number).pipe(z.int().positive())
});
registry.registerPath({
method: "delete",
@@ -58,16 +60,38 @@ export async function deleteClient(
);
}
await db.transaction(async (trx) => {
// Delete the client-site associations first
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
if (client.userId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Cannot delete a user client with this endpoint`
)
);
}
await db.transaction(async (trx) => {
// Then delete the client itself
await trx
const [deletedClient] = await trx
.delete(clients)
.where(eq(clients.clientId, clientId));
.where(eq(clients.clientId, clientId))
.returning();
const [olm] = await trx
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
// this is a machine client so we also delete the olm
if (!client.userId && client.olmId) {
await trx.delete(olms).where(eq(olms.olmId, client.olmId));
}
await rebuildClientAssociationsFromClient(deletedClient, trx);
if (olm) {
                await sendTerminateClient(deletedClient.clientId, olm.olmId); // the olmId needs to be provided because it can't be looked up after deletion
}
});
return response(res, {

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq, and } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -29,9 +29,9 @@ async function query(clientId: number) {
// Get the siteIds associated with this client
const sites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
.select({ siteId: clientSitesAssociationsCache.siteId })
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.clientId, clientId));
// Add the siteIds to the client object
return {

View File

@@ -3,4 +3,5 @@ export * from "./createClient";
export * from "./deleteClient";
export * from "./listClients";
export * from "./updateClient";
export * from "./getClient";
export * from "./getClient";
export * from "./createUserClient";

View File

@@ -1,16 +1,16 @@
import { db, olms } from "@server/db";
import { db, olms, users } from "@server/db";
import {
clients,
orgs,
roleClients,
sites,
userClients,
clientSites
clientSitesAssociationsCache
} from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { and, count, eq, inArray, or, sql } from "drizzle-orm";
import { and, count, eq, inArray, isNotNull, isNull, or, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
@@ -19,7 +19,7 @@ import { OpenAPITags, registry } from "@server/openApi";
import NodeCache from "node-cache";
import semver from "semver";
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
async function getLatestOlmVersion(): Promise<string | null> {
try {
@@ -29,7 +29,7 @@ async function getLatestOlmVersion(): Promise<string | null> {
}
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 1500);
const timeoutId = setTimeout(() => controller.abort(), 1500);
const response = await fetch(
"https://api.github.com/repos/fosrl/olm/tags",
@@ -94,10 +94,25 @@ const listClientsSchema = z.object({
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
.pipe(z.int().nonnegative()),
filter: z
.enum(["user", "machine"])
.optional()
});
function queryClients(orgId: string, accessibleClientIds: number[]) {
function queryClients(orgId: string, accessibleClientIds: number[], filter?: "user" | "machine") {
const conditions = [
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
];
// Add filter condition based on filter type
if (filter === "user") {
conditions.push(isNotNull(clients.userId));
} else if (filter === "machine") {
conditions.push(isNull(clients.userId));
}
return db
.select({
clientId: clients.clientId,
@@ -110,17 +125,16 @@ function queryClients(orgId: string, accessibleClientIds: number[]) {
orgName: orgs.name,
type: clients.type,
online: clients.online,
olmVersion: olms.version
olmVersion: olms.version,
userId: clients.userId,
username: users.username,
userEmail: users.email
})
.from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.leftJoin(olms, eq(clients.clientId, olms.clientId))
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
.leftJoin(users, eq(clients.userId, users.userId))
.where(and(...conditions));
}
async function getSiteAssociations(clientIds: number[]) {
@@ -128,14 +142,14 @@ async function getSiteAssociations(clientIds: number[]) {
return db
.select({
clientId: clientSites.clientId,
siteId: clientSites.siteId,
clientId: clientSitesAssociationsCache.clientId,
siteId: clientSitesAssociationsCache.siteId,
siteName: sites.name,
siteNiceId: sites.niceId
})
.from(clientSites)
.leftJoin(sites, eq(clientSites.siteId, sites.siteId))
.where(inArray(clientSites.clientId, clientIds));
.from(clientSitesAssociationsCache)
.leftJoin(sites, eq(clientSitesAssociationsCache.siteId, sites.siteId))
.where(inArray(clientSitesAssociationsCache.clientId, clientIds));
}
type OlmWithUpdateAvailable = Awaited<ReturnType<typeof queryClients>>[0] & {
@@ -182,7 +196,7 @@ export async function listClients(
)
);
}
const { limit, offset } = parsedQuery.data;
const { limit, offset, filter } = parsedQuery.data;
const parsedParams = listClientsParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
@@ -231,18 +245,24 @@ export async function listClients(
const accessibleClientIds = accessibleClients.map(
(client) => client.clientId
);
const baseQuery = queryClients(orgId, accessibleClientIds);
const baseQuery = queryClients(orgId, accessibleClientIds, filter);
// Get client count with filter
const countConditions = [
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
];
if (filter === "user") {
countConditions.push(isNotNull(clients.userId));
} else if (filter === "machine") {
countConditions.push(isNull(clients.userId));
}
// Get client count
const countQuery = db
.select({ count: count() })
.from(clients)
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
.where(and(...countConditions));
const clientsList = await baseQuery.limit(limit).offset(offset);
const totalCountResult = await countQuery;
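// Usage sketch for the new `filter` query parameter on GET /org/:orgId/clients
// (semantics per the conditions above; the URL shape is an assumption):
//
//   GET /org/{orgId}/clients?limit=50&offset=0&filter=user     -> only clients with a userId (user devices)
//   GET /org/{orgId}/clients?limit=50&offset=0&filter=machine  -> only clients without a userId (machine clients)
//
// Omitting `filter` returns both kinds, as before.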

View File

@@ -1,35 +1,136 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db";
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
export async function addTargets(
newtId: string,
destinationIp: string,
destinationPort: number,
protocol: string,
port: number
) {
const target = `${port}:${destinationIp}:${destinationPort}`;
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
await sendToClient(newtId, {
type: `newt/wg/${protocol}/add`,
data: {
targets: [target] // We can only use one target for WireGuard right now
}
type: `newt/wg/targets/add`,
data: targets
});
}
export async function removeTargets(
newtId: string,
destinationIp: string,
destinationPort: number,
protocol: string,
port: number
targets: SubnetProxyTarget[]
) {
const target = `${port}:${destinationIp}:${destinationPort}`;
await sendToClient(newtId, {
type: `newt/wg/${protocol}/remove`,
data: {
targets: [target] // We can only use one target for WireGuard right now
}
type: `newt/wg/targets/remove`,
data: targets
});
}
export async function updateTargets(
newtId: string,
targets: {
oldTargets: SubnetProxyTarget[];
newTargets: SubnetProxyTarget[];
}
) {
await sendToClient(newtId, {
type: `newt/wg/targets/update`,
data: targets
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
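// Usage sketch (variable names hypothetical): replace a newt's previously
// pushed proxy targets with a new set in a single message.
//
//   await updateTargets(newt.newtId, {
//       oldTargets: currentTargets, // SubnetProxyTarget[] pushed earlier
//       newTargets: desiredTargets  // SubnetProxyTarget[] to apply instead
//   });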
export async function addPeerData(
clientId: number,
siteId: number,
remoteSubnets: string[],
aliases: Alias[],
olmId?: string
) {
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
}
await sendToClient(olmId, {
type: `olm/wg/peer/data/add`,
data: {
siteId: siteId,
remoteSubnets: remoteSubnets,
aliases: aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
export async function removePeerData(
clientId: number,
siteId: number,
remoteSubnets: string[],
aliases: Alias[],
olmId?: string
) {
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
return;
}
olmId = olm.olmId;
}
await sendToClient(olmId, {
type: `olm/wg/peer/data/remove`,
data: {
siteId: siteId,
remoteSubnets: remoteSubnets,
aliases: aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
export async function updatePeerData(
clientId: number,
siteId: number,
remoteSubnets: {
oldRemoteSubnets: string[];
newRemoteSubnets: string[];
} | undefined,
aliases: {
oldAliases: Alias[];
newAliases: Alias[];
} | undefined,
olmId?: string
) {
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
return;
}
olmId = olm.olmId;
}
await sendToClient(olmId, {
type: `olm/wg/peer/data/update`,
data: {
siteId: siteId,
...remoteSubnets,
...aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
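// Usage sketch (IDs and subnets hypothetical): push newly granted peer data to
// a client's olm. When olmId is omitted, the olm is looked up by clientId and
// the call is silently skipped if no olm is associated with the client anymore.
//
//   await addPeerData(
//       42,                // clientId
//       7,                 // siteId
//       ["10.1.0.0/24"],   // remoteSubnets now reachable through the site
//       []                 // aliases (Alias[] from @server/lib/ip)
//   );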

View File

@@ -0,0 +1,22 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db";
import { eq } from "drizzle-orm";
export async function sendTerminateClient(clientId: number, olmId?: string | null) {
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm for client ${clientId} not found`);
}
olmId = olm.olmId;
}
await sendToClient(olmId, {
type: `olm/terminate`,
data: {}
});
}
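// Usage sketch: once the olm row has been deleted it can no longer be looked
// up by clientId, so pass the olmId captured before deletion.
//
//   await sendTerminateClient(deletedClient.clientId, deletedOlm.olmId);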

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { Client, db, exitNodes, olms, sites } from "@server/db";
import { clients, clientSites } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -9,27 +9,14 @@ import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "../newt/peers";
import {
addPeer as olmAddPeer,
deletePeer as olmDeletePeer
} from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { hashPassword } from "@server/auth/password";
const updateClientParamsSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
const updateClientSchema = z.strictObject({
name: z.string().min(1).max(255).optional(),
siteIds: z
.array(z.int().positive())
.optional(),
});
name: z.string().min(1).max(255).optional()
});
export type UpdateClientBody = z.infer<typeof updateClientSchema>;
@@ -51,11 +38,6 @@ registry.registerPath({
responses: {}
});
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
export async function updateClient(
req: Request,
res: Response,
@@ -72,7 +54,7 @@ export async function updateClient(
);
}
const { name, siteIds } = parsedBody.data;
const { name } = parsedBody.data;
const parsedParams = updateClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
@@ -86,7 +68,6 @@ export async function updateClient(
const { clientId } = parsedParams.data;
// Fetch the client to make sure it exists and the user has access to it
const [client] = await db
.select()
@@ -103,266 +84,11 @@ export async function updateClient(
);
}
let sitesAdded = [];
let sitesRemoved = [];
// Fetch existing site associations
const existingSites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
const existingSiteIds = existingSites.map((site) => site.siteId);
const siteIdsToProcess = siteIds || [];
// Determine which sites were added and removed
sitesAdded = siteIdsToProcess.filter(
(siteId) => !existingSiteIds.includes(siteId)
);
sitesRemoved = existingSiteIds.filter(
(siteId) => !siteIdsToProcess.includes(siteId)
);
let updatedClient: Client | undefined = undefined;
let sitesData: any; // TODO: define type somehow from the query below
await db.transaction(async (trx) => {
// Update client name if provided
if (name) {
await trx
.update(clients)
.set({ name })
.where(eq(clients.clientId, clientId));
}
// Update site associations if provided
// Remove sites that are no longer associated
for (const siteId of sitesRemoved) {
await trx
.delete(clientSites)
.where(
and(
eq(clientSites.clientId, clientId),
eq(clientSites.siteId, siteId)
)
);
}
// Add new site associations
for (const siteId of sitesAdded) {
await trx.insert(clientSites).values({
clientId,
siteId
});
}
// Fetch the updated client
[updatedClient] = await trx
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
// get all sites for this client and join with exit nodes with site.exitNodeId
sitesData = await trx
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId))
.where(eq(clientSites.clientId, client.clientId));
});
logger.info(
`Adding ${sitesAdded.length} new sites to client ${client.clientId}`
);
for (const siteId of sitesAdded) {
if (!client.subnet || !client.pubKey) {
logger.debug("Client subnet, pubKey or endpoint is not set");
continue;
}
// TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES
// BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS
// AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES
const isRelayed = true;
const site = await newtAddPeer(siteId, {
publicKey: client.pubKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
// endpoint: isRelayed ? "" : clientSite.endpoint
endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint
});
if (!site) {
logger.debug("Failed to add peer to newt - missing site");
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
let endpoint;
if (isRelayed) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} has no exit node, skipping`
);
return null;
}
// get the exit node for the site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`);
return null;
}
endpoint = `${exitNode.endpoint}:21820`;
} else {
if (!site.endpoint) {
logger.warn(
`Site ${site.siteId} has no endpoint, skipping`
);
return null;
}
endpoint = site.endpoint;
}
await olmAddPeer(client.clientId, {
siteId: site.siteId,
endpoint: endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: site.remoteSubnets
});
}
logger.info(
`Removing ${sitesRemoved.length} sites from client ${client.clientId}`
);
for (const siteId of sitesRemoved) {
if (!client.pubKey) {
logger.debug("Client pubKey is not set");
continue;
}
const site = await newtDeletePeer(siteId, client.pubKey);
if (!site) {
logger.debug("Failed to delete peer from newt - missing site");
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
await olmDeletePeer(client.clientId, site.siteId, site.publicKey);
}
if (!updatedClient || !sitesData) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to update client`
)
);
}
let exitNodeDestinations: {
reachableAt: string;
exitNodeId: number;
type: string;
name: string;
sourceIp: string;
sourcePort: number;
destinations: PeerDestination[];
}[] = [];
for (const site of sitesData) {
if (!site.sites.subnet) {
logger.warn(
`Site ${site.sites.siteId} has no subnet, skipping`
);
continue;
}
if (!site.clientSites.endpoint) {
logger.warn(
`Site ${site.sites.siteId} has no endpoint, skipping`
);
continue;
}
// find the destinations in the array
let destinations = exitNodeDestinations.find(
(d) => d.reachableAt === site.exitNodes?.reachableAt
);
if (!destinations) {
destinations = {
reachableAt: site.exitNodes?.reachableAt || "",
exitNodeId: site.exitNodes?.exitNodeId || 0,
type: site.exitNodes?.type || "",
name: site.exitNodes?.name || "",
sourceIp: site.clientSites.endpoint.split(":")[0] || "",
sourcePort:
parseInt(site.clientSites.endpoint.split(":")[1]) || 0,
destinations: [
{
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
}
]
};
} else {
// add to the existing destinations
destinations.destinations.push({
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
});
}
// update it in the array
exitNodeDestinations = exitNodeDestinations.filter(
(d) => d.reachableAt !== site.exitNodes?.reachableAt
);
exitNodeDestinations.push(destinations);
}
for (const destination of exitNodeDestinations) {
logger.info(
`Updating destinations for exit node at ${destination.reachableAt}`
);
const payload = {
sourceIp: destination.sourceIp,
sourcePort: destination.sourcePort,
destinations: destination.destinations
};
logger.info(
`Payload for update-destinations: ${JSON.stringify(payload, null, 2)}`
);
// Create an ExitNode-like object for sendToExitNode
const exitNodeForComm = {
exitNodeId: destination.exitNodeId,
type: destination.type,
reachableAt: destination.reachableAt,
name: destination.name
} as any; // Using 'as any' since we know sendToExitNode will handle this correctly
await sendToExitNode(exitNodeForComm, {
remoteType: "remoteExitNode/update-destinations",
localPath: "/update-destinations",
method: "POST",
data: payload
});
}
const updatedClient = await db
.update(clients)
.set({ name })
.where(eq(clients.clientId, clientId))
.returning();
return response(res, {
data: updatedClient,

View File

@@ -16,6 +16,8 @@ import * as idp from "./idp";
import * as blueprints from "./blueprints";
import * as apiKeys from "./apiKeys";
import * as logs from "./auditLogs";
import * as newt from "./newt";
import * as olm from "./olm";
import HttpCode from "@server/types/HttpCode";
import {
verifyAccessTokenAccess,
@@ -27,6 +29,7 @@ import {
verifyTargetAccess,
verifyRoleAccess,
verifySetResourceUsers,
verifySetResourceClients,
verifyUserAccess,
getUserOrgs,
verifyUserIsServerAdmin,
@@ -34,14 +37,12 @@ import {
verifyClientAccess,
verifyApiKeyAccess,
verifyDomainAccess,
verifyClientsEnabled,
verifyUserHasAction,
verifyUserIsOrgOwner,
verifySiteResourceAccess
verifySiteResourceAccess,
verifyOlmAccess
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import { createNewt, getNewtToken } from "./newt";
import { getOlmToken } from "./olm";
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
import createHttpError from "http-errors";
import { build } from "@server/build";
@@ -129,7 +130,6 @@ authenticated.get(
authenticated.get(
"/org/:orgId/pick-client-defaults",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.pickClientDefaults
@@ -137,7 +137,6 @@ authenticated.get(
authenticated.get(
"/org/:orgId/clients",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listClients),
client.listClients
@@ -145,7 +144,6 @@ authenticated.get(
authenticated.get(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.getClient),
client.getClient
@@ -153,16 +151,15 @@ authenticated.get(
authenticated.put(
"/org/:orgId/client",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient),
client.createClient
);
// TODO: Separate into a deleteUserClient (for user clients) and deleteClient (for machine clients)
authenticated.delete(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.deleteClient),
logActionAudit(ActionsEnum.deleteClient),
@@ -171,7 +168,6 @@ authenticated.delete(
authenticated.post(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess, // this will check if the user has access to the client
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
logActionAudit(ActionsEnum.updateClient),
@@ -286,6 +282,72 @@ authenticated.delete(
siteResource.deleteSiteResource
);
authenticated.get(
"/site-resource/:siteResourceId/roles",
verifySiteResourceAccess,
verifyUserHasAction(ActionsEnum.listResourceRoles),
siteResource.listSiteResourceRoles
);
authenticated.get(
"/site-resource/:siteResourceId/users",
verifySiteResourceAccess,
verifyUserHasAction(ActionsEnum.listResourceUsers),
siteResource.listSiteResourceUsers
);
authenticated.get(
"/site-resource/:siteResourceId/clients",
verifySiteResourceAccess,
verifyUserHasAction(ActionsEnum.listResourceUsers),
siteResource.listSiteResourceClients
);
authenticated.post(
"/site-resource/:siteResourceId/roles",
verifySiteResourceAccess,
verifyRoleAccess,
verifyUserHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles,
);
authenticated.post(
"/site-resource/:siteResourceId/users",
verifySiteResourceAccess,
verifySetResourceUsers,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers,
);
authenticated.post(
"/site-resource/:siteResourceId/clients",
verifySiteResourceAccess,
verifySetResourceClients,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients,
);
authenticated.post(
"/site-resource/:siteResourceId/clients/add",
verifySiteResourceAccess,
verifySetResourceClients,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource,
);
authenticated.post(
"/site-resource/:siteResourceId/clients/remove",
verifySiteResourceAccess,
verifySetResourceClients,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource,
);
authenticated.put(
"/org/:orgId/resource",
verifyOrgAccess,
@@ -649,9 +711,15 @@ unauthenticated.get(
// );
unauthenticated.get("/user", verifySessionMiddleware, user.getUser);
unauthenticated.get("/my-device", verifySessionMiddleware, user.myDevice);
authenticated.get("/users", verifyUserIsServerAdmin, user.adminListUsers);
authenticated.get("/user/:userId", verifyUserIsServerAdmin, user.adminGetUser);
authenticated.post(
"/user/:userId/generate-password-reset-code",
verifyUserIsServerAdmin,
user.adminGeneratePasswordResetCode
);
authenticated.delete(
"/user/:userId",
verifyUserIsServerAdmin,
@@ -734,6 +802,32 @@ authenticated.delete(
// createNewt
// );
authenticated.put(
"/user/:userId/olm",
verifyIsLoggedInUser,
olm.createUserOlm
);
authenticated.get(
"/user/:userId/olms",
verifyIsLoggedInUser,
olm.listUserOlms
);
authenticated.delete(
"/user/:userId/olm/:olmId",
verifyIsLoggedInUser,
verifyOlmAccess,
olm.deleteUserOlm
);
authenticated.get(
"/user/:userId/olm/:olmId",
verifyIsLoggedInUser,
verifyOlmAccess,
olm.getUserOlm
);
authenticated.put(
"/idp/oidc",
verifyUserIsServerAdmin,
@@ -993,7 +1087,7 @@ authRouter.post(
},
store: createStore()
}),
getNewtToken
newt.getNewtToken
);
authRouter.post(
"/olm/get-token",
@@ -1008,7 +1102,7 @@ authRouter.post(
},
store: createStore()
}),
getOlmToken
olm.getOlmToken
);
authRouter.post(
@@ -1253,3 +1347,51 @@ authRouter.delete(
}),
auth.deleteSecurityKey
);
authRouter.post(
"/device-web-auth/start",
rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 30, // Allow 30 device auth code requests per 15 minutes per IP
keyGenerator: (req) =>
`deviceWebAuthStart:${ipKeyGenerator(req.ip || "")}`,
handler: (req, res, next) => {
const message = `You can only request a device auth code ${30} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
},
store: createStore()
}),
auth.startDeviceWebAuth
);
authRouter.get(
"/device-web-auth/poll/:code",
rateLimit({
windowMs: 60 * 1000, // 1 minute
max: 60, // Allow 60 polling requests per minute per IP (poll every second)
keyGenerator: (req) =>
`deviceWebAuthPoll:${ipKeyGenerator(req.ip || "")}:${req.params.code}`,
handler: (req, res, next) => {
const message = `You can only poll a device auth code ${60} times per minute. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
},
store: createStore()
}),
auth.pollDeviceWebAuth
);
authenticated.post(
"/device-web-auth/verify",
rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 50, // Allow 50 verification attempts per 15 minutes per user
keyGenerator: (req) =>
`deviceWebAuthVerify:${req.user?.userId || ipKeyGenerator(req.ip || "")}`,
handler: (req, res, next) => {
const message = `You can only verify a device auth code ${50} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
},
store: createStore()
}),
auth.verifyDeviceWebAuth
);
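// Flow sketch for the three device-web-auth endpoints above (paths as
// registered; payload field names are assumptions, not shown in this diff):
//
//   1. Device:  POST /device-web-auth/start         -> receives a short code
//   2. Device:  GET  /device-web-auth/poll/{code}   -> poll ~1/s (limit 60/min)
//      until the code is approved, then receive the session credentials
//   3. Browser (logged-in user): POST /device-web-auth/verify with the code
//      shown on the device, approving the login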

View File

@@ -7,7 +7,7 @@ import {
olms,
Site,
sites,
clientSites,
clientSitesAssociationsCache,
ExitNode
} from "@server/db";
import { db } from "@server/db";
@@ -109,8 +109,8 @@ export async function generateRelayMappings(exitNode: ExitNode) {
// Find all clients associated with this site through clientSitesAssociationsCache
const clientSitesRes = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, site.siteId));
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.siteId, site.siteId));
for (const clientSite of clientSitesRes) {
if (!clientSite.endpoint) {

View File

@@ -6,7 +6,7 @@ import {
olms,
Site,
sites,
clientSites,
clientSitesAssociationsCache,
exitNodes,
ExitNode
} from "@server/db";
@@ -19,6 +19,8 @@ import { fromError } from "zod-validation-error";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
import { updatePeer as updateOlmPeer } from "../olm/peers";
import { updatePeer as updateNewtPeer } from "../newt/peers";
// Define Zod schema for request validation
const updateHolePunchSchema = z.object({
@@ -28,8 +30,9 @@ const updateHolePunchSchema = z.object({
ip: z.string(),
port: z.number(),
timestamp: z.number(),
publicKey: z.string(),
reachableAt: z.string().optional(),
publicKey: z.string().optional()
exitNodePublicKey: z.string().optional()
});
// New response type with multi-peer destination support
@@ -63,23 +66,26 @@ export async function updateHolePunch(
timestamp,
token,
reachableAt,
publicKey
publicKey, // this is the client's current public key for this session
exitNodePublicKey
} = parsedParams.data;
let exitNode: ExitNode | undefined;
if (publicKey) {
if (exitNodePublicKey) {
// Get the exit node by public key
[exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.publicKey, publicKey));
.where(eq(exitNodes.publicKey, exitNodePublicKey));
} else {
// FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL <= 1.1.0
[exitNode] = await db.select().from(exitNodes).limit(1);
}
if (!exitNode) {
logger.warn(`Exit node not found for publicKey: ${publicKey}`);
logger.warn(
`Exit node not found for publicKey: ${exitNodePublicKey}`
);
return next(
createHttpError(HttpCode.NOT_FOUND, "Exit node not found")
);
@@ -92,12 +98,13 @@ export async function updateHolePunch(
port,
timestamp,
token,
publicKey,
exitNode
);
logger.debug(
`Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}`
);
// logger.debug(
// `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}`
// );
// Return the new multi-peer structure
return res.status(HttpCode.OK).send({
@@ -121,6 +128,7 @@ export async function updateAndGenerateEndpointDestinations(
port: number,
timestamp: number,
token: string,
publicKey: string,
exitNode: ExitNode,
checkOrg = false
) {
@@ -128,9 +136,9 @@ export async function updateAndGenerateEndpointDestinations(
const destinations: PeerDestination[] = [];
if (olmId) {
logger.debug(
`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`
);
// logger.debug(
// `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`
// );
const { session, olm: olmSession } =
await validateOlmSessionToken(token);
@@ -150,7 +158,7 @@ export async function updateAndGenerateEndpointDestinations(
throw new Error("Olm not found");
}
const [client] = await db
const [updatedClient] = await db
.update(clients)
.set({
lastHolePunch: timestamp
@@ -158,10 +166,16 @@ export async function updateAndGenerateEndpointDestinations(
.where(eq(clients.clientId, olm.clientId))
.returning();
if (await checkExitNodeOrg(exitNode.exitNodeId, client.orgId) && checkOrg) {
if (
(await checkExitNodeOrg(
exitNode.exitNodeId,
updatedClient.orgId
)) &&
checkOrg
) {
// not allowed
logger.warn(
`Exit node ${exitNode.exitNodeId} is not allowed for org ${client.orgId}`
`Exit node ${exitNode.exitNodeId} is not allowed for org ${updatedClient.orgId}`
);
throw new Error("Exit node not allowed");
}
@@ -171,40 +185,70 @@ export async function updateAndGenerateEndpointDestinations(
.select({
siteId: sites.siteId,
subnet: sites.subnet,
listenPort: sites.listenPort
listenPort: sites.listenPort,
publicKey: sites.publicKey,
endpoint: clientSitesAssociationsCache.endpoint
})
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(
and(
eq(sites.exitNodeId, exitNode.exitNodeId),
eq(clientSites.clientId, olm.clientId)
eq(clientSitesAssociationsCache.clientId, olm.clientId)
)
);
// Update clientSitesAssociationsCache entries for each site on this exit node
for (const site of sitesOnExitNode) {
logger.debug(
`Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
);
// logger.debug(
// `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
// );
await db
.update(clientSites)
// if the public key or endpoint has changed, update it otherwise continue
if (
site.endpoint === `${ip}:${port}` &&
site.publicKey === publicKey
) {
continue;
}
const [updatedClientSitesAssociationsCache] = await db
.update(clientSitesAssociationsCache)
.set({
endpoint: `${ip}:${port}`
endpoint: `${ip}:${port}`,
publicKey: publicKey
})
.where(
and(
eq(clientSites.clientId, olm.clientId),
eq(clientSites.siteId, site.siteId)
eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.returning();
if (
updatedClientSitesAssociationsCache.endpoint !==
site.endpoint && // this is the endpoint from the join table not the site
updatedClient.pubKey === publicKey // only trigger if the client's public key matches the current public key, which means it has registered, so we don't prematurely send the update
) {
logger.info(
`ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}`
);
// Handle any additional logic for endpoint change
handleClientEndpointChange(
olm.clientId,
updatedClientSitesAssociationsCache.endpoint!
);
}
}
logger.debug(
`Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}`
);
if (!client) {
// logger.debug(
// `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}`
// );
if (!updatedClient) {
logger.warn(`Client not found for olm: ${olmId}`);
throw new Error("Client not found");
}
@@ -219,9 +263,9 @@ export async function updateAndGenerateEndpointDestinations(
}
}
} else if (newtId) {
logger.debug(
`Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}`
);
// logger.debug(
// `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}`
// );
const { session, newt: newtSession } =
await validateNewtSessionToken(token);
@@ -253,7 +297,10 @@ export async function updateAndGenerateEndpointDestinations(
.where(eq(sites.siteId, newt.siteId))
.limit(1);
if (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId) && checkOrg) {
if (
(await checkExitNodeOrg(exitNode.exitNodeId, site.orgId)) &&
checkOrg
) {
// not allowed
logger.warn(
`Exit node ${exitNode.exitNodeId} is not allowed for org ${site.orgId}`
@@ -273,6 +320,18 @@ export async function updateAndGenerateEndpointDestinations(
.where(eq(sites.siteId, newt.siteId))
.returning();
if (
updatedSite.endpoint != site.endpoint &&
updatedSite.publicKey == publicKey
) {
// only trigger if the site's public key matches the current public key, which means it has registered, so we don't prematurely send the update
logger.info(
`Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}`
);
// Handle any additional logic for endpoint change
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!);
}
if (!updatedSite || !updatedSite.subnet) {
logger.warn(`Site not found: ${newt.siteId}`);
throw new Error("Site not found");
@@ -326,3 +385,143 @@ export async function updateAndGenerateEndpointDestinations(
}
return destinations;
}
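// Each entry in `destinations` follows the PeerDestination shape used across
// this module: { destinationIP: string; destinationPort: number }.
// Example (values hypothetical):
//   [{ destinationIP: "100.89.130.1", destinationPort: 51820 }]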
async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
// Alert all clients connected to this site that the endpoint has changed (only if NOT relayed)
try {
// Get site details
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site || !site.publicKey) {
logger.warn(`Site ${siteId} not found or has no public key`);
return;
}
// Get all non-relayed clients connected to this site
const connectedClients = await db
.select({
clientId: clients.clientId,
olmId: olms.olmId,
isRelayed: clientSitesAssociationsCache.isRelayed
})
.from(clientSitesAssociationsCache)
.innerJoin(
clients,
eq(clientSitesAssociationsCache.clientId, clients.clientId)
)
.innerJoin(olms, eq(olms.clientId, clients.clientId))
.where(
and(
eq(clientSitesAssociationsCache.siteId, siteId),
eq(clientSitesAssociationsCache.isRelayed, false)
)
);
// Update each non-relayed client with the new site endpoint
for (const client of connectedClients) {
try {
await updateOlmPeer(
client.clientId,
{
siteId: siteId,
publicKey: site.publicKey,
endpoint: newEndpoint
},
client.olmId
);
logger.debug(
`Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}`
);
} catch (error) {
logger.error(
`Failed to update client ${client.clientId} with new site endpoint: ${error}`
);
}
}
} catch (error) {
logger.error(
`Error handling site endpoint change for site ${siteId}: ${error}`
);
}
}
async function handleClientEndpointChange(
clientId: number,
newEndpoint: string
) {
// Alert all sites connected to this client that the endpoint has changed (only if NOT relayed)
try {
// Get client details
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client || !client.pubKey) {
logger.warn(`Client ${clientId} not found or has no public key`);
return;
}
// Get all non-relayed sites connected to this client
const connectedSites = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
isRelayed: clientSitesAssociationsCache.isRelayed,
subnet: clients.subnet
})
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(clientSitesAssociationsCache.siteId, sites.siteId)
)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.innerJoin(
clients,
eq(clientSitesAssociationsCache.clientId, clients.clientId)
)
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.isRelayed, false)
)
);
// Update each non-relayed site with the new client endpoint
for (const siteData of connectedSites) {
try {
if (!siteData.subnet) {
logger.warn(
`Client ${clientId} has no subnet, skipping update for site ${siteData.siteId}`
);
continue;
}
await updateNewtPeer(
siteData.siteId,
client.pubKey,
{
endpoint: newEndpoint
},
siteData.newtId
);
logger.debug(
`Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}`
);
} catch (error) {
logger.error(
`Failed to update site ${siteData.siteId} with new client endpoint: ${error}`
);
}
}
} catch (error) {
logger.error(
`Error handling client endpoint change for client ${clientId}: ${error}`
);
}
}

View File

@@ -33,6 +33,7 @@ import { UserType } from "@server/types/UserTypes";
import { FeatureId } from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
const ensureTrailingSlash = (url: string): string => {
return url;
@@ -364,10 +365,18 @@ export async function validateOidcCallback(
);
if (!existingUserOrgs.length) {
// delete the user
// await db
// .delete(users)
// .where(eq(users.userId, existingUser.userId));
// delete all auto-provisioned user orgs
await db
.delete(userOrgs)
.where(
and(
eq(userOrgs.userId, existingUser.userId),
eq(userOrgs.autoProvisioned, true)
)
);
await calculateUserClientsForOrgs(existingUser.userId);
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
@@ -513,6 +522,8 @@ export async function validateOidcCallback(
userCount: userCount.length
});
}
await calculateUserClientsForOrgs(userId!, trx);
});
for (const orgCount of orgUserCounts) {
@@ -553,6 +564,24 @@ export async function validateOidcCallback(
);
}
// check for existing user orgs
const existingUserOrgs = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.userId, existingUser.userId)));
if (!existingUserOrgs.length) {
logger.debug(
"No existing user orgs found for non-auto-provisioned IdP"
);
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
`User with username ${userIdentifier} is unprovisioned. This user must be added to an organization before logging in.`
)
);
}
const token = generateSessionToken();
const sess = await createSession(token, existingUser.userId);
const isSecure = req.protocol === "https";

View File

@@ -10,6 +10,7 @@ import * as client from "./client";
import * as accessToken from "./accessToken";
import * as apiKeys from "./apiKeys";
import * as idp from "./idp";
import * as logs from "./auditLogs";
import * as siteResource from "./siteResource";
import {
verifyApiKey,
@@ -24,8 +25,8 @@ import {
verifyApiKeyAccessTokenAccess,
verifyApiKeyIsRoot,
verifyApiKeyClientAccess,
verifyClientsEnabled,
verifyApiKeySiteResourceAccess
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients
} from "@server/middlewares";
import HttpCode from "@server/types/HttpCode";
import { Router } from "express";
@@ -197,6 +198,108 @@ authenticated.delete(
siteResource.deleteSiteResource
);
authenticated.get(
"/site-resource/:siteResourceId/roles",
verifyApiKeySiteResourceAccess,
verifyApiKeyHasAction(ActionsEnum.listResourceRoles),
siteResource.listSiteResourceRoles
);
authenticated.get(
"/site-resource/:siteResourceId/users",
verifyApiKeySiteResourceAccess,
verifyApiKeyHasAction(ActionsEnum.listResourceUsers),
siteResource.listSiteResourceUsers
);
authenticated.get(
"/site-resource/:siteResourceId/clients",
verifyApiKeySiteResourceAccess,
verifyApiKeyHasAction(ActionsEnum.listResourceUsers),
siteResource.listSiteResourceClients
);
authenticated.post(
"/site-resource/:siteResourceId/roles",
verifyApiKeySiteResourceAccess,
verifyApiKeyRoleAccess,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles
);
authenticated.post(
"/site-resource/:siteResourceId/users",
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceUsers,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers
);
authenticated.post(
"/site-resource/:siteResourceId/roles/add",
verifyApiKeySiteResourceAccess,
verifyApiKeyRoleAccess,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
siteResource.addRoleToSiteResource
);
authenticated.post(
"/site-resource/:siteResourceId/roles/remove",
verifyApiKeySiteResourceAccess,
verifyApiKeyRoleAccess,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
siteResource.removeRoleFromSiteResource
);
authenticated.post(
"/site-resource/:siteResourceId/users/add",
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceUsers,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addUserToSiteResource
);
authenticated.post(
"/site-resource/:siteResourceId/users/remove",
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceUsers,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeUserFromSiteResource
);
authenticated.post(
"/site-resource/:siteResourceId/clients",
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients
);
authenticated.post(
"/site-resource/:siteResourceId/clients/add",
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource
);
authenticated.post(
"/site-resource/:siteResourceId/clients/remove",
verifyApiKeySiteResourceAccess,
verifyApiKeySetResourceClients,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource
);
authenticated.put(
"/org/:orgId/resource",
verifyApiKeyOrgAccess,
@@ -412,6 +515,42 @@ authenticated.post(
resource.setResourceUsers
);
authenticated.post(
"/resource/:resourceId/roles/add",
verifyApiKeyResourceAccess,
verifyApiKeyRoleAccess,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
resource.addRoleToResource
);
authenticated.post(
"/resource/:resourceId/roles/remove",
verifyApiKeyResourceAccess,
verifyApiKeyRoleAccess,
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
resource.removeRoleFromResource
);
authenticated.post(
"/resource/:resourceId/users/add",
verifyApiKeyResourceAccess,
verifyApiKeySetResourceUsers,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
resource.addUserToResource
);
authenticated.post(
"/resource/:resourceId/users/remove",
verifyApiKeyResourceAccess,
verifyApiKeySetResourceUsers,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
resource.removeUserFromResource
);
authenticated.post(
`/resource/:resourceId/password`,
verifyApiKeyResourceAccess,
@@ -657,7 +796,6 @@ authenticated.get(
authenticated.get(
"/org/:orgId/pick-client-defaults",
verifyClientsEnabled,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.createClient),
client.pickClientDefaults
@@ -665,7 +803,6 @@ authenticated.get(
authenticated.get(
"/org/:orgId/clients",
verifyClientsEnabled,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.listClients),
client.listClients
@@ -673,7 +810,6 @@ authenticated.get(
authenticated.get(
"/client/:clientId",
verifyClientsEnabled,
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.getClient),
client.getClient
@@ -681,16 +817,24 @@ authenticated.get(
authenticated.put(
"/org/:orgId/client",
verifyClientsEnabled,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient),
client.createClient
);
// authenticated.put(
// "/org/:orgId/user/:userId/client",
// verifyClientsEnabled,
// verifyApiKeyOrgAccess,
// verifyApiKeyUserAccess,
// verifyApiKeyHasAction(ActionsEnum.createClient),
// logActionAudit(ActionsEnum.createClient),
// client.createUserClient
// );
authenticated.delete(
"/client/:clientId",
verifyClientsEnabled,
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.deleteClient),
logActionAudit(ActionsEnum.deleteClient),
@@ -699,7 +843,6 @@ authenticated.delete(
authenticated.post(
"/client/:clientId",
verifyClientsEnabled,
verifyApiKeyClientAccess,
verifyApiKeyHasAction(ActionsEnum.updateClient),
logActionAudit(ActionsEnum.updateClient),
@@ -713,3 +856,32 @@ authenticated.put(
logActionAudit(ActionsEnum.applyBlueprint),
blueprints.applyJSONBlueprint
);
authenticated.get(
"/org/:orgId/logs/request",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.viewLogs),
logs.queryRequestAuditLogs
);
authenticated.get(
"/org/:orgId/logs/request/export",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
logs.exportRequestAuditLogs
);
authenticated.get(
"/org/:orgId/logs/analytics",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.viewLogs),
logs.queryRequestAnalytics
);
authenticated.get(
"/org/:orgId/resource-names",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.listResources),
resource.listAllResourceNames
);

View File

@@ -6,15 +6,15 @@ import {
db,
ExitNode,
exitNodes,
resources,
siteResources,
Target,
targets
clientSiteResourcesAssociationsCache
} from "@server/db";
import { clients, clientSites, Newt, sites } from "@server/db";
import { eq, and, inArray } from "drizzle-orm";
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
import config from "@server/lib/config";
const inputSchema = z.object({
publicKey: z.string(),
@@ -66,7 +66,9 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
// we need to wait for hole punch success
if (!existingSite.endpoint) {
logger.debug(`In newt get config: existing site ${existingSite.siteId} has no endpoint, skipping`);
logger.debug(
`In newt get config: existing site ${existingSite.siteId} has no endpoint, skipping`
);
return;
}
@@ -74,12 +76,12 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
// TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly)
}
// if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) {
// logger.warn(
// `Site ${existingSite.siteId} last hole punch is too old, skipping`
// );
// return;
// }
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
logger.warn(
`handleGetConfigMessage: Site ${existingSite.siteId} last hole punch is too old, skipping`
);
return;
}
// update the endpoint and the public key
const [site] = await db
@@ -132,75 +134,95 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
const clientsRes = await db
.select()
.from(clients)
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
.innerJoin(
clientSitesAssociationsCache,
eq(clients.clientId, clientSitesAssociationsCache.clientId)
)
.where(eq(clientSitesAssociationsCache.siteId, siteId));
// Prepare peers data for the response
const peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
logger.warn(
`Client ${client.clients.clientId} has no public key, skipping`
);
return false;
}
if (!client.clients.subnet) {
logger.warn(
`Client ${client.clients.clientId} has no subnet, skipping`
);
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
try {
if (!site.publicKey) {
logger.warn(
`Site ${site.siteId} has no public key, skipping`
);
return null;
}
let endpoint = site.endpoint;
if (client.clientSites.isRelayed) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} has no exit node, skipping`
);
return null;
}
if (!exitNode) {
logger.warn(
`Exit node not found for site ${site.siteId}`
);
return null;
}
endpoint = `${exitNode.endpoint}:21820`;
}
if (!endpoint) {
logger.warn(
`In Newt get config: Peer site ${site.siteId} has no endpoint, skipping`
);
return null;
}
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: site.remoteSubnets
});
} catch (error) {
logger.error(
`Failed to add/update peer ${client.clients.pubKey} to olm ${newt.newtId}: ${error}`
if (!site.publicKey) {
logger.warn(
`Site ${site.siteId} has no public key, skipping`
);
return null;
}
if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`);
return null;
}
if (!site.endpoint) {
logger.warn(
`Site ${site.siteId} has no endpoint, skipping`
);
return null;
}
// const allSiteResources = await db // only get the site resources that this client has access to
// .select()
// .from(siteResources)
// .innerJoin(
// clientSiteResourcesAssociationsCache,
// eq(
// siteResources.siteResourceId,
// clientSiteResourcesAssociationsCache.siteResourceId
// )
// )
// .where(
// and(
// eq(siteResources.siteId, site.siteId),
// eq(
// clientSiteResourcesAssociationsCache.clientId,
// client.clients.clientId
// )
// )
// );
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
return {
publicKey: client.clients.pubKey!,
allowedIps: [`${client.clients.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: client.clientSites.isRelayed
endpoint: client.clientSitesAssociationsCache.isRelayed
? ""
: client.clientSites.endpoint! // if its relayed it should be localhost
: client.clientSitesAssociationsCache.endpoint! // if it's relayed it should be localhost
};
})
);
@@ -208,42 +230,50 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled targets with their resource protocol information
// Get all enabled site resources for this site
const allSiteResources = await db
.select()
.from(siteResources)
.where(eq(siteResources.siteId, siteId));
const { tcpTargets, udpTargets } = allSiteResources.reduce(
(acc, resource) => {
// Filter out invalid targets
if (!resource.proxyPort || !resource.destinationIp || !resource.destinationPort) {
return acc;
}
const targetsToSend: SubnetProxyTarget[] = [];
// Format target into string
const formattedTarget = `${resource.proxyPort}:${resource.destinationIp}:${resource.destinationPort}`;
for (const resource of allSiteResources) {
// Get clients associated with this specific resource
const resourceClients = await db
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clients.clientId,
clientSiteResourcesAssociationsCache.clientId
)
)
.where(
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
resource.siteResourceId
)
);
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(formattedTarget);
} else {
acc.udpTargets.push(formattedTarget);
}
const resourceTargets = generateSubnetProxyTargets(
resource,
resourceClients
);
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
targetsToSend.push(...resourceTargets);
}
// Build the configuration response
const configResponse = {
ipAddress: site.address,
peers: validPeers,
targets: {
udp: udpTargets,
tcp: tcpTargets
}
targets: targetsToSend
};
logger.debug("Sending config: ", configResponse);

View File

@@ -1,8 +1,8 @@
import { db, exitNodeOrgs, newts } from "@server/db";
import { db, ExitNode, exitNodeOrgs, newts, Transaction } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { exitNodes, Newt, resources, sites, Target, targets } from "@server/db";
import { targetHealthCheck } from "@server/db";
import { eq, and, sql, inArray } from "drizzle-orm";
import { eq, and, sql, inArray, ne } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
import config from "@server/lib/config";
@@ -17,6 +17,7 @@ import {
verifyExitNodeOrgAccess
} from "#dynamic/lib/exitNodes";
import { fetchContainers } from "./dockerSocket";
import { lockManager } from "#dynamic/lib/lock";
export type ExitNodePingResult = {
exitNodeId: number;
@@ -151,27 +152,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
return;
}
const sitesQuery = await db
.select({
subnet: sites.subnet
})
.from(sites)
.where(eq(sites.exitNodeId, exitNodeId));
const newSubnet = await getUniqueSubnetForSite(exitNode);
const blockSize = config.getRawConfig().gerbil.site_block_size;
const subnets = sitesQuery
.map((site) => site.subnet)
.filter(
(subnet) =>
subnet && /^(\d{1,3}\.){3}\d{1,3}\/\d{1,2}$/.test(subnet)
)
.filter((subnet) => subnet !== null);
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
const newSubnet = findNextAvailableCidr(
subnets,
blockSize,
exitNode.address
);
if (!newSubnet) {
logger.error(
`No available subnets found for the new exit node id ${exitNodeId} and site id ${siteId}`
@@ -378,3 +360,39 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
excludeSender: false // Include sender in broadcast
};
};
async function getUniqueSubnetForSite(
exitNode: ExitNode,
trx: Transaction | typeof db = db
): Promise<string | null> {
const lockKey = `subnet-allocation:${exitNode.exitNodeId}`;
return await lockManager.withLock(
lockKey,
async () => {
const sitesQuery = await trx
.select({
subnet: sites.subnet
})
.from(sites)
.where(eq(sites.exitNodeId, exitNode.exitNodeId));
const blockSize = config.getRawConfig().gerbil.site_block_size;
const subnets = sitesQuery
.map((site) => site.subnet)
.filter(
(subnet) =>
subnet && /^(\d{1,3}\.){3}\d{1,3}\/\d{1,2}$/.test(subnet)
)
.filter((subnet) => subnet !== null);
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
const newSubnet = findNextAvailableCidr(
subnets,
blockSize,
exitNode.address
);
return newSubnet;
},
5000 // 5 second lock TTL - subnet allocation should be quick
);
}
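// Locking sketch: lockManager.withLock(key, fn, ttlMs) (as used above) is
// assumed to serialize fn across concurrent callers holding the same key, so
// two sites registering on the same exit node cannot both read the subnet
// table at once and claim the same free CIDR.
//
//   const subnet = await getUniqueSubnetForSite(exitNode);
//   if (!subnet) {
//       // exit node address space exhausted for this block size
//   }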

View File

@@ -1,4 +1,4 @@
import { db } from "@server/db";
import { db, Site } from "@server/db";
import { newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws";
@@ -10,65 +10,78 @@ export async function addPeer(
publicKey: string;
allowedIps: string[];
endpoint: string;
}
},
newtId?: string
) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
let site: Site | null = null;
if (!newtId) {
[site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
newtId = newt.newtId;
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
await sendToClient(newtId, {
type: "newt/wg/peer/add",
data: peer
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`);
logger.info(`Added peer ${peer.publicKey} to newt ${newtId}`);
return site;
}
export async function deletePeer(siteId: number, publicKey: string) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
export async function deletePeer(siteId: number, publicKey: string, newtId?: string) {
let site: Site | null = null;
if (!newtId) {
[site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
newtId = newt.newtId;
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
await sendToClient(newtId, {
type: "newt/wg/peer/remove",
data: {
publicKey
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`);
logger.info(`Deleted peer ${publicKey} from newt ${newtId}`);
return site;
}
@@ -79,36 +92,43 @@ export async function updatePeer(
peer: {
allowedIps?: string[];
endpoint?: string;
}
},
newtId?: string
) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
let site: Site | null = null;
if (!newtId) {
[site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
newtId = newt.newtId;
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
await sendToClient(newtId, {
type: "newt/wg/peer/update",
data: {
publicKey,
...peer
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`);
logger.info(`Updated peer ${publicKey} on newt ${newtId}`);
return site;
}
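// Usage sketch: passing newtId explicitly (e.g. from a prior join) skips the
// site and newt lookups entirely; note that `site` is then returned as null.
//
//   await updatePeer(siteData.siteId, client.pubKey, { endpoint: newEndpoint }, siteData.newtId);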

View File

@@ -0,0 +1,116 @@
import { NextFunction, Request, Response } from "express";
import { db, olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import moment from "moment";
import { generateId } from "@server/auth/sessions/app";
import { fromError } from "zod-validation-error";
import { hashPassword } from "@server/auth/password";
import { OpenAPITags, registry } from "@server/openApi";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
const bodySchema = z
.object({
name: z.string().min(1).max(255)
})
.strict();
const paramsSchema = z.object({
userId: z.string()
});
export type CreateOlmBody = z.infer<typeof bodySchema>;
export type CreateOlmResponse = {
olmId: string;
secret: string;
};
// registry.registerPath({
// method: "put",
// path: "/user/{userId}/olm",
// description: "Create a new olm for a user.",
// tags: [OpenAPITags.User, OpenAPITags.Client],
// request: {
// body: {
// content: {
// "application/json": {
// schema: bodySchema
// }
// }
// },
// params: paramsSchema
// },
// responses: {}
// });
export async function createUserOlm(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name } = parsedBody.data;
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { userId } = parsedParams.data;
const olmId = generateId(15);
const secret = generateId(48);
const secretHash = await hashPassword(secret);
await db.transaction(async (trx) => {
await trx.insert(olms).values({
olmId: olmId,
userId,
name,
secretHash,
dateCreated: moment().toISOString()
});
await calculateUserClientsForOrgs(userId, trx);
});
return response<CreateOlmResponse>(res, {
data: {
olmId,
secret
},
success: true,
error: false,
message: "Olm created successfully",
status: HttpCode.OK
});
} catch (e) {
console.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create olm"
)
);
}
}
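// Usage sketch against the route registered as PUT /user/{userId}/olm
// (fetch call and base path are assumptions):
//
//   const res = await fetch(`/user/${userId}/olm`, {
//       method: "PUT",
//       headers: { "Content-Type": "application/json" },
//       body: JSON.stringify({ name: "Work laptop" })
//   });
//   const { data } = await res.json(); // { olmId, secret }
//   // Store `secret` now: only its hash is persisted server-side.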

View File

@@ -0,0 +1,101 @@
import { NextFunction, Request, Response } from "express";
import { Client, db } from "@server/db";
import { olms, clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { sendTerminateClient } from "../client/terminate";
const paramsSchema = z
.object({
userId: z.string(),
olmId: z.string()
})
.strict();
// registry.registerPath({
// method: "delete",
// path: "/user/{userId}/olm/{olmId}",
// description: "Delete an olm for a user.",
// tags: [OpenAPITags.User, OpenAPITags.Client],
// request: {
// params: paramsSchema
// },
// responses: {}
// });
export async function deleteUserOlm(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId } = parsedParams.data;
// Delete associated clients and the OLM in a transaction
await db.transaction(async (trx) => {
// Find all clients associated with this OLM
const associatedClients = await trx
.select({ clientId: clients.clientId })
.from(clients)
.where(eq(clients.olmId, olmId));
let deletedClient: Client | null = null;
// Delete all associated clients
if (associatedClients.length > 0) {
[deletedClient] = await trx
.delete(clients)
.where(eq(clients.olmId, olmId))
.returning();
}
// Finally, delete the OLM itself
const [olm] = await trx
.delete(olms)
.where(eq(olms.olmId, olmId))
.returning();
if (deletedClient) {
await rebuildClientAssociationsFromClient(deletedClient, trx);
if (olm) {
await sendTerminateClient(
deletedClient.clientId,
olm.olmId
                    ); // the olmId needs to be provided because it can't be looked up after deletion
}
}
});
return response(res, {
data: null,
success: true,
error: false,
message: "Device deleted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to delete device"
)
);
}
}

View File

@@ -1,9 +1,16 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import { db } from "@server/db";
import {
clients,
db,
ExitNode,
exitNodes,
sites,
clientSitesAssociationsCache
} from "@server/db";
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
@@ -19,7 +26,8 @@ import config from "@server/lib/config";
export const olmGetTokenBodySchema = z.object({
olmId: z.string(),
secret: z.string(),
token: z.string().optional()
token: z.string().optional(),
orgId: z.string().optional()
});
export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>;
@@ -40,7 +48,7 @@ export async function getOlmToken(
);
}
const { olmId, secret, token } = parsedBody.data;
const { olmId, secret, token, orgId } = parsedBody.data;
try {
if (token) {
@@ -61,11 +69,12 @@ export async function getOlmToken(
}
}
const existingOlmRes = await db
const [existingOlm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!existingOlmRes || !existingOlmRes.length) {
if (!existingOlm) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
@@ -74,12 +83,11 @@ export async function getOlmToken(
);
}
const existingOlm = existingOlmRes[0];
const validSecret = await verifyPassword(
secret,
existingOlm.secretHash
);
if (!validSecret) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
@@ -96,11 +104,115 @@ export async function getOlmToken(
const resToken = generateSessionToken();
await createOlmSession(resToken, existingOlm.olmId);
let orgIdToUse = orgId;
let clientIdToUse;
if (!orgIdToUse) {
if (!existingOlm.clientId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Olm is not associated with a client, orgId is required"
)
);
}
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, existingOlm.clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Olm's associated client not found, orgId is required"
)
);
}
orgIdToUse = client.orgId;
clientIdToUse = client.clientId;
} else {
// we did provide the org
const [client] = await db
.select()
.from(clients)
.where(
and(eq(clients.orgId, orgIdToUse), eq(clients.olmId, olmId))
) // we want to lock on to the client with this olmId otherwise it can get assigned to a random one
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No client found for provided orgId"
)
);
}
if (existingOlm.clientId !== client.clientId) {
// we only need to do this if the client is changing
logger.debug(
`Switching olm client ${existingOlm.olmId} to org ${orgId} for user ${existingOlm.userId}`
);
await db
.update(olms)
.set({
clientId: client.clientId
})
.where(eq(olms.olmId, existingOlm.olmId));
}
clientIdToUse = client.clientId;
}
// Get all exit nodes from sites where the client has peers
const clientSites = await db
.select()
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!));
// Extract unique exit node IDs
const exitNodeIds = Array.from(
new Set(
clientSites
.map(({ sites: site }) => site.exitNodeId)
.filter((id): id is number => id !== null)
)
);
let allExitNodes: ExitNode[] = [];
if (exitNodeIds.length > 0) {
allExitNodes = await db
.select()
.from(exitNodes)
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
}
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
return {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
};
});
logger.debug("Token created successfully");
return response<{ token: string }>(res, {
return response<{
token: string;
exitNodes: { publicKey: string; endpoint: string }[];
}>(res, {
data: {
token: resToken
token: resToken,
exitNodes: exitNodesHpData
},
success: true,
error: false,
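Putting the pieces together, the token exchange now also seeds the olm with holepunch targets. A sketch of the round trip from the olm side (the route path and /api/v1 prefix are assumptions; the body and response shapes come from olmGetTokenBodySchema and the response type above):

    type OlmTokenResponse = {
        token: string;
        exitNodes: { publicKey: string; endpoint: string }[];
    };

    // Hypothetical olm-side call; orgId is optional and, when present,
    // re-pins the olm to the matching client in that org as shown above.
    async function fetchOlmToken(olmId: string, secret: string, orgId?: string) {
        const res = await fetch("/api/v1/olm/get-token", {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify({ olmId, secret, orgId })
        });
        const body = await res.json();
        return body.data as OlmTokenResponse;
    }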

View File

@@ -0,0 +1,70 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { olms } from "@server/db";
import { eq, and } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
const paramsSchema = z
.object({
userId: z.string(),
olmId: z.string()
})
.strict();
// registry.registerPath({
// method: "get",
// path: "/user/{userId}/olm/{olmId}",
// description: "Get an olm for a user.",
// tags: [OpenAPITags.User, OpenAPITags.Client],
// request: {
// params: paramsSchema
// },
// responses: {}
// });
export async function getUserOlm(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId, userId } = parsedParams.data;
const [olm] = await db
.select()
.from(olms)
.where(and(eq(olms.userId, userId), eq(olms.olmId, olmId)));
return response(res, {
data: olm,
success: true,
error: false,
message: "Successfully retrieved olm",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to retrieve olm"
)
);
}
}

View File

@@ -1,8 +1,12 @@
import { db } from "@server/db";
import { disconnectClient } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
@@ -20,10 +24,14 @@ export const startOlmOfflineChecker = (): void => {
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
await db
const offlineClients = await db
.update(clients)
.set({ online: false })
.where(
@@ -34,8 +42,34 @@ export const startOlmOfflineChecker = (): void => {
isNull(clients.lastPing)
)
)
)
.returning();
for (const offlineClient of offlineClients) {
logger.info(
`Kicking offline olm client ${offlineClient.clientId} due to inactivity`
);
if (!offlineClient.olmId) {
logger.warn(
`Offline client ${offlineClient.clientId} has no olmId, cannot disconnect`
);
continue;
}
// Send a disconnect message to the client if connected
try {
await sendTerminateClient(offlineClient.clientId, offlineClient.olmId); // terminate first
// wait a moment to ensure the message is sent
await new Promise(resolve => setTimeout(resolve, 1000));
await disconnectClient(offlineClient.olmId);
} catch (error) {
logger.error(
`Error sending disconnect to offline olm ${offlineClient.clientId}`,
{ error }
);
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
@@ -62,11 +96,57 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
const { userToken } = message.data;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (olm.userId) {
// we need to check a user token to make sure its still valid
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
logger.warn("Invalid user session for olm ping");
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
}
if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm ping");
return;
}
// get the client
const [client] = await db
.select()
.from(clients)
.where(
and(
eq(clients.olmId, olm.olmId),
eq(clients.userId, olm.userId)
)
)
.limit(1);
if (!client) {
logger.warn("Client not found for olm ping");
return;
}
const policyCheck = await checkOrgAccessPolicy({
orgId: client.orgId,
userId: olm.userId,
session: userToken // this is the user token passed in the message
});
if (!policyCheck.allowed) {
logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
);
return;
}
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
@@ -78,7 +158,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
.update(clients)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true,
online: true
})
.where(eq(clients.clientId, olm.clientId));
} catch (error) {
@@ -89,7 +169,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
timestamp: new Date().toISOString()
}
},
broadcast: false,
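For user-owned olms the ping is no longer a bare heartbeat: every ping re-validates the user session and re-checks the org access policy, so a revoked session or failed policy simply stops the pings from being acknowledged and the offline checker eventually disconnects the client. A sketch of what an olm would send (the envelope type string is an assumption; only the userToken field is confirmed by this handler):

    // Hypothetical ping message from a user-owned olm.
    const pingMessage = {
        type: "olm/ping", // assumed type string
        data: {
            userToken: "<current user session token>" // re-validated on every ping
        }
    };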

View File

@@ -1,31 +1,115 @@
import { db, ExitNode } from "@server/db";
import {
Client,
clientSiteResourcesAssociationsCache,
db,
ExitNode,
Org,
orgs,
roleClients,
roles,
siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import {
clients,
clientSitesAssociationsCache,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { and, eq, inArray, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import {
generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
const now = new Date().getTime() / 1000;
const now = Math.floor(Date.now() / 1000);
if (!olm) {
logger.warn("Olm not found");
return;
}
const { publicKey, relay, olmVersion, olmAgent, orgId, userToken } = message.data;
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
logger.warn("Olm client ID not found");
return;
}
const clientId = olm.clientId;
const { publicKey, relay, olmVersion } = message.data;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
if (!client) {
logger.warn("Client ID not found");
return;
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, client.orgId))
.limit(1);
if (!org) {
logger.warn("Org not found");
return;
}
if (orgId) {
if (!olm.userId) {
logger.warn("Olm has no user ID");
return;
}
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
logger.warn("Invalid user session for olm register");
            return; // by returning here we just ignore the register and the offline checker interval will force a disconnect
}
if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm register");
return;
}
const policyCheck = await checkOrgAccessPolicy({
orgId: orgId,
userId: olm.userId,
session: userToken // this is the user token passed in the message
});
if (!policyCheck.allowed) {
logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
);
return;
}
}
logger.debug(
`Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}`
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
);
if (!publicKey) {
@@ -33,66 +117,16 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) {
// TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER
// Get the exit node
const allExitNodes = await listExitNodes(client.orgId, true); // FILTER THE ONLINE ONES
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
return {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
};
});
// Send holepunch message
await sendToClient(olm.olmId, {
type: "olm/wg/holepunch/all",
data: {
exitNodes: exitNodesHpData
}
});
if (!olmVersion) {
// THIS IS FOR BACKWARDS COMPATIBILITY
// THE OLDER CLIENTS DID NOT SEND THE VERSION
await sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: allExitNodes[0].publicKey,
endpoint: allExitNodes[0].endpoint
}
});
}
}
if (olmVersion) {
if ((olmVersion && olm.version !== olmVersion) || (olmAgent && olm.agent !== olmAgent)) {
await db
.update(olms)
.set({
version: olmVersion
version: olmVersion,
agent: olmAgent
})
.where(eq(olms.olmId, olm.olmId));
}
// if (now - (client.lastHolePunch || 0) > 6) {
// logger.warn("Client last hole punch is too old, skipping all sites");
// return;
// }
if (client.pubKey !== publicKey) {
logger.info(
"Public key mismatch. Updating public key and clearing session info..."
@@ -103,23 +137,26 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.set({
pubKey: publicKey
})
.where(eq(clients.clientId, olm.clientId));
.where(eq(clients.clientId, client.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
.update(clientSites)
.update(clientSitesAssociationsCache)
.set({
isRelayed: relay == true
})
.where(eq(clientSites.clientId, olm.clientId));
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
}
// Get all sites data
const sitesData = await db
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.where(eq(clientSites.clientId, client.clientId));
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Prepare an array to store site configurations
const siteConfigurations = [];
@@ -127,15 +164,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
`Found ${sitesData.length} sites for client ${client.clientId}`
);
if (sitesData.length === 0) {
sendToClient(olm.olmId, {
type: "olm/register/no-sites",
data: {}
});
// this prevents us from accepting a register from an olm that has not hole punched yet.
    // the olm will keep re-sending the register, so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ???
if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) {
logger.warn(
"Client last hole punch is too old and we have sites to send; skipping this register"
);
return;
}
// Process each site
for (const { sites: site } of sitesData) {
for (const { sites: site, clientSitesAssociationsCache: association } of sitesData) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
@@ -145,7 +185,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`In olm register: site ${site.siteId} has no endpoint, skipping`);
logger.warn(
`In olm register: site ${site.siteId} has no endpoint, skipping`
);
continue;
}
@@ -171,11 +213,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const [clientSite] = await db
.select()
.from(clientSites)
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSites.clientId, client.clientId),
eq(clientSites.siteId, site.siteId)
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.limit(1);
@@ -196,7 +238,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
);
}
let endpoint = site.endpoint;
let relayEndpoint: string | undefined = undefined;
if (relay) {
const [exitNode] = await db
.select()
@@ -207,17 +249,43 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.warn(`Exit node not found for site ${site.siteId}`);
continue;
}
endpoint = `${exitNode.endpoint}:21820`;
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
}
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
endpoint: endpoint,
            // relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // let's not do this for now because it would conflict with the hole punch testing
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: site.remoteSubnets
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(({ siteResources }) => siteResources)
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
});
}
@@ -233,7 +301,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
type: "olm/wg/connect",
data: {
sites: siteConfigurations,
tunnelIP: client.subnet
tunnelIP: client.subnet,
utilitySubnet: org.utilitySubnet
}
},
broadcast: false,
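For reference, the shape of the olm/wg/connect payload assembled above, written out as a type. This is a sketch derived from siteConfigurations and the final return value; the alias entry type is not shown in this diff, so it is left opaque:

    type OlmConnectPayload = {
        sites: {
            siteId: number;
            endpoint: string;               // site's direct endpoint
            publicKey: string;
            serverIP: string | null;        // site.address
            serverPort: number | null;      // site.listenPort
            remoteSubnets: string[] | null; // derived from accessible site resources
            aliases: unknown[];             // entries from generateAliasConfig
        }[];
        tunnelIP: string | null;      // client.subnet
        utilitySubnet: string | null; // org.utilitySubnet
    };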

View File

@@ -1,8 +1,8 @@
import { db, exitNodes, sites } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, Olm } from "@server/db";
import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
import { and, eq } from "drizzle-orm";
import { updatePeer } from "../newt/peers";
import { updatePeer as newtUpdatePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRelayMessage: MessageHandler = async (context) => {
@@ -67,30 +67,31 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
}
await db
.update(clientSites)
.update(clientSitesAssociationsCache)
.set({
isRelayed: true
})
.where(
and(
eq(clientSites.clientId, olm.clientId),
eq(clientSites.siteId, siteId)
eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, siteId)
)
);
// update the peer on the exit node
await updatePeer(siteId, client.pubKey, {
endpoint: "" // this removes the endpoint
await newtUpdatePeer(siteId, client.pubKey, {
endpoint: "" // this removes the endpoint so the exit node knows to relay
});
sendToClient(olm.olmId, {
type: "olm/wg/peer/relay",
data: {
siteId: siteId,
endpoint: exitNode.endpoint,
publicKey: exitNode.publicKey
}
});
return;
return {
message: {
type: "olm/wg/peer/relay",
data: {
siteId: siteId,
relayEndpoint: exitNode.endpoint
}
},
broadcast: false,
excludeSender: false
};
};
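The handler now returns the relay message through the ws layer instead of pushing it with sendToClient, and the payload slims down to just the relay endpoint. Roughly what the olm receives (values illustrative):

    const relayMsg = {
        type: "olm/wg/peer/relay",
        data: {
            siteId: 7,                                // illustrative
            relayEndpoint: "exit-node.example.com"    // exitNode.endpoint, illustrative
        }
    };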

View File

@@ -0,0 +1,187 @@
import {
Client,
clientSiteResourcesAssociationsCache,
db,
ExitNode,
Org,
orgs,
roleClients,
roles,
siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import {
clients,
clientSitesAssociationsCache,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import {
generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers";
export const handleOlmServerPeerAddMessage: MessageHandler = async (
context
) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
const now = Math.floor(Date.now() / 1000);
if (!olm) {
logger.warn("Olm not found");
return;
}
const { siteId } = message.data;
// get the site
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${siteId} not found`
);
return;
}
if (!site.endpoint) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${siteId} has no endpoint`
);
return;
}
// get the client
if (!olm.clientId) {
logger.error(
`handleOlmServerPeerAddMessage: Olm with ID ${olm.olmId} has no clientId`
);
return;
}
const [client] = await db
.select()
.from(clients)
.where(and(eq(clients.clientId, olm.clientId)))
.limit(1);
if (!client) {
logger.error(
`handleOlmServerPeerAddMessage: Client with ID ${olm.clientId} not found`
);
return;
}
if (!client.pubKey) {
logger.error(
`handleOlmServerPeerAddMessage: Client with ID ${client.clientId} has no public key`
);
return;
}
let endpoint: string | null = null;
    // TODO: should we pick only the one from the site it's talking to instead of any good current session?
const currentSessionSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
isNotNull(clientSitesAssociationsCache.endpoint),
eq(clientSitesAssociationsCache.publicKey, client.pubKey) // limit it to the current session its connected with otherwise the endpoint could be stale
)
);
// pick an endpoint
for (const assoc of currentSessionSiteAssociationCaches) {
if (assoc.endpoint) {
endpoint = assoc.endpoint;
break;
}
}
if (!endpoint) {
logger.error(
`handleOlmServerPeerAddMessage: No endpoint found for client ${client.clientId}`
);
return;
}
// NOTE: here we are always starting direct to the peer and will relay later
await newtAddPeer(siteId, {
publicKey: client.pubKey,
        allowedIps: [`${client.subnet.split("/")[0]}/32`], // only allow traffic from that client's address
endpoint: endpoint // this is the client's endpoint with reference to the site's exit node
});
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Return connect message with all site configurations
return {
message: {
type: "olm/wg/peer/add",
data: {
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(({ siteResources }) => siteResources)
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
}
},
broadcast: false,
excludeSender: false
};
};
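The endpoint selection above deliberately trusts only association rows recorded under the client's current public key, since rows from an earlier WireGuard session would carry stale endpoints. The loop is equivalent to this one-liner (a sketch, same names as the handler):

    // First non-null endpoint among current-session associations, else null.
    const endpoint =
        currentSessionSiteAssociationCaches.find((a) => a.endpoint)?.endpoint ??
        null;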

View File

@@ -0,0 +1,96 @@
import { db, exitNodes, sites } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
import { and, eq } from "drizzle-orm";
import { updatePeer as newtUpdatePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
logger.info("Handling unrelay olm message!");
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
    // make sure the client has a public key before unrelaying
    if (!client.pubKey) {
        logger.warn("Client has no public key");
return;
}
const { siteId } = message.data;
// Get the site
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
logger.warn("Site not found or has no exit node");
return;
}
const [clientSiteAssociation] = await db
.update(clientSitesAssociationsCache)
.set({
isRelayed: false
})
.where(
and(
eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, siteId)
)
)
.returning();
if (!clientSiteAssociation) {
logger.warn("Client-Site association not found");
return;
}
if (!clientSiteAssociation.endpoint) {
logger.warn("Client-Site association has no endpoint, cannot unrelay");
return;
}
// update the peer on the exit node
await newtUpdatePeer(siteId, client.pubKey, {
endpoint: clientSiteAssociation.endpoint // this is the endpoint of the client to connect directly to the exit node
});
return {
message: {
type: "olm/wg/peer/unrelay",
data: {
siteId: siteId,
endpoint: site.endpoint
}
},
broadcast: false,
excludeSender: false
};
};
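The unrelay reply mirrors the relay one: instead of a relayEndpoint it hands back the site's direct endpoint, and the exit node peer is updated with the client's cached endpoint so traffic can flow directly again. Roughly what the olm receives (values illustrative):

    const unrelayMsg = {
        type: "olm/wg/peer/unrelay",
        data: {
            siteId: 7,                      // illustrative
            endpoint: "203.0.113.10:51820"  // site.endpoint, illustrative
        }
    };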

View File

@@ -1,5 +1,11 @@
export * from "./handleOlmRegisterMessage";
export * from "./getOlmToken";
export * from "./createOlm";
export * from "./createUserOlm";
export * from "./handleOlmRelayMessage";
export * from "./handleOlmPingMessage";
export * from "./handleOlmPingMessage";
export * from "./deleteUserOlm";
export * from "./listUserOlms";
export * from "./deleteUserOlm";
export * from "./getUserOlm";
export * from "./handleOlmServerPeerAddMessage";
export * from "./handleOlmUnRelayMessage";

View File

@@ -0,0 +1,139 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { olms } from "@server/db";
import { eq, count, desc } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
const querySchema = z.object({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.number().int().positive()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.number().int().nonnegative())
});
const paramsSchema = z
.object({
userId: z.string()
})
.strict();
// registry.registerPath({
// method: "delete",
// path: "/user/{userId}/olms",
// description: "List all olms for a user.",
// tags: [OpenAPITags.User, OpenAPITags.Client],
// request: {
// query: querySchema,
// params: paramsSchema
// },
// responses: {}
// });
export type ListUserOlmsResponse = {
olms: Array<{
olmId: string;
dateCreated: string;
version: string | null;
name: string | null;
clientId: number | null;
userId: string | null;
}>;
pagination: {
total: number;
limit: number;
offset: number;
};
};
export async function listUserOlms(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedQuery = querySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { limit, offset } = parsedQuery.data;
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { userId } = parsedParams.data;
// Get total count
const [totalCountResult] = await db
.select({ count: count() })
.from(olms)
.where(eq(olms.userId, userId));
const total = totalCountResult?.count || 0;
// Get OLMs for the current user
const userOlms = await db
.select({
olmId: olms.olmId,
dateCreated: olms.dateCreated,
version: olms.version,
name: olms.name,
clientId: olms.clientId,
userId: olms.userId
})
.from(olms)
.where(eq(olms.userId, userId))
.orderBy(desc(olms.dateCreated))
.limit(limit)
.offset(offset);
return response<ListUserOlmsResponse>(res, {
data: {
olms: userOlms,
pagination: {
total,
limit,
offset
}
},
success: true,
error: false,
message: "Olms retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to list OLMs"
)
);
}
}
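Because querySchema coerces limit and offset from strings, a malformed value like limit=abc fails validation with a 400 rather than producing NaN. A hedged sketch of a caller, assuming an /api/v1 prefix and the path from the commented-out registration above:

    // Hypothetical caller of GET /user/{userId}/olms.
    async function listOlms(userId: string, limit = 50, offset = 0) {
        const qs = new URLSearchParams({
            limit: String(limit),
            offset: String(offset)
        });
        const res = await fetch(`/api/v1/user/${userId}/olms?${qs}`);
        const body = await res.json();
        return body.data as ListUserOlmsResponse;
    }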

View File

@@ -3,6 +3,7 @@ import { clients, olms, newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { Alias } from "yaml";
export async function addPeer(
clientId: number,
@@ -10,54 +11,74 @@ export async function addPeer(
siteId: number;
publicKey: string;
endpoint: string;
relayEndpoint: string;
serverIP: string | null;
serverPort: number | null;
remoteSubnets: string | null; // optional, comma-separated list of subnets that this site can access
}
        remoteSubnets: string[] | null; // optional list of subnets that this site can access
aliases: Alias[];
},
olmId?: string
) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
}
await sendToClient(olm.olmId, {
await sendToClient(olmId, {
type: "olm/wg/peer/add",
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
relayEndpoint: peer.relayEndpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort,
remoteSubnets: peer.remoteSubnets // optional, comma-separated list of subnets that this site can access
            remoteSubnets: peer.remoteSubnets, // optional list of subnets that this site can access
aliases: peer.aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
}
export async function deletePeer(clientId: number, siteId: number, publicKey: string) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
export async function deletePeer(
clientId: number,
siteId: number,
publicKey: string,
olmId?: string
) {
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
return;
}
olmId = olm.olmId;
}
await sendToClient(olm.olmId, {
await sendToClient(olmId, {
type: "olm/wg/peer/remove",
data: {
publicKey,
siteId: siteId
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`);
logger.info(`Deleted peer ${publicKey} from olm ${olmId}`);
}
export async function updatePeer(
@@ -66,31 +87,80 @@ export async function updatePeer(
siteId: number;
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
remoteSubnets?: string | null; // optional, comma-separated list of subnets that
}
relayEndpoint?: string;
serverIP?: string | null;
serverPort?: number | null;
        remoteSubnets?: string[] | null; // optional list of subnets that this site can access
aliases?: Alias[] | null;
},
olmId?: string
) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
            return;
}
olmId = olm.olmId;
}
await sendToClient(olm.olmId, {
await sendToClient(olmId, {
type: "olm/wg/peer/update",
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
relayEndpoint: peer.relayEndpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort,
remoteSubnets: peer.remoteSubnets
remoteSubnets: peer.remoteSubnets,
aliases: peer.aliases
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
logger.info(`Updated peer ${peer.publicKey} on olm ${olmId}`);
}
export async function initPeerAddHandshake(
clientId: number,
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
},
olmId?: string
) {
if (!olmId) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
return;
}
olmId = olm.olmId;
}
await sendToClient(olmId, {
type: "olm/wg/peer/holepunch/site/add",
data: {
siteId: peer.siteId,
exitNode: {
publicKey: peer.exitNode.publicKey,
endpoint: peer.exitNode.endpoint
}
}
}).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.info(`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`);
}
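All four helpers share the same fallback: resolve the olm by clientId unless the caller supplies one, and silently no-op if none is associated anymore. Passing olmId explicitly matters most right after a delete, when the lookup can no longer succeed; this is the same reason deleteUserOlm above hands sendTerminateClient the olmId directly. A sketch of both call styles:

    // Explicit olmId: required once the olms row is gone, and it also saves
    // a query when the caller already holds the olm.
    await deletePeer(client.clientId, site.siteId, client.pubKey, olm.olmId);

    // Implicit: the helper resolves the olm by clientId and returns quietly
    // if the client has no associated olm.
    await deletePeer(client.clientId, site.siteId, client.pubKey);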

View File

@@ -26,12 +26,13 @@ import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
const createOrgSchema = z.strictObject({
orgId: z.string(),
name: z.string().min(1).max(255),
subnet: z.string()
});
orgId: z.string(),
name: z.string().min(1).max(255),
subnet: z.string()
});
registry.registerPath({
method: "put",
@@ -131,12 +132,16 @@ export async function createOrg(
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet =
config.getRawConfig().orgs.utility_subnet_group;
const newOrg = await trx
.insert(orgs)
.values({
orgId,
name,
subnet,
utilitySubnet,
createdAt: new Date().toISOString()
})
.returning();
@@ -190,6 +195,7 @@ export async function createOrg(
);
}
let ownerUserId: string | null = null;
if (req.user) {
await trx.insert(userOrgs).values({
userId: req.user!.userId,
@@ -197,6 +203,7 @@ export async function createOrg(
roleId: roleId,
isOwner: true
});
ownerUserId = req.user!.userId;
} else {
// if org created by root api key, set the server admin as the owner
const [serverAdmin] = await trx
@@ -216,6 +223,7 @@ export async function createOrg(
roleId: roleId,
isOwner: true
});
ownerUserId = serverAdmin.userId;
}
const memberRole = await trx
@@ -234,6 +242,8 @@ export async function createOrg(
orgId
}))
);
await calculateUserClientsForOrgs(ownerUserId, trx);
});
if (!org) {
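The utility subnet is read from static config once at org creation and stored on the org row, so later changes to the global default do not retroactively move existing orgs. A sketch of the lookup (the key name is from this diff; the example value is illustrative):

    // orgs.utility_subnet_group is a server-wide default, e.g. a CIDR block
    // such as "100.90.128.0/24" (illustrative), frozen per-org at creation.
    const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;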

View File

@@ -1,6 +1,15 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, domains, orgDomains, resources } from "@server/db";
import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
domains,
olms,
orgDomains,
resources
} from "@server/db";
import { newts, newtSessions, orgs, sites, userActions } from "@server/db";
import { eq, and, inArray, sql } from "drizzle-orm";
import response from "@server/lib/response";
@@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers";
import { OpenAPITags, registry } from "@server/openApi";
const deleteOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export type DeleteOrgResponse = {};
@@ -69,41 +78,75 @@ export async function deleteOrg(
.where(eq(sites.orgId, orgId))
.limit(1);
const orgClients = await db
.select()
.from(clients)
.where(eq(clients.orgId, orgId));
const deletedNewtIds: string[] = [];
const olmsToTerminate: string[] = [];
await db.transaction(async (trx) => {
if (sites) {
for (const site of orgSites) {
if (site.pubKey) {
if (site.type == "wireguard") {
await deletePeer(site.exitNodeId!, site.pubKey);
} else if (site.type == "newt") {
// get the newt on the site by querying the newt table for siteId
const [deletedNewt] = await trx
.delete(newts)
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
deletedNewtIds.push(deletedNewt.newtId);
for (const site of orgSites) {
if (site.pubKey) {
if (site.type == "wireguard") {
await deletePeer(site.exitNodeId!, site.pubKey);
} else if (site.type == "newt") {
// get the newt on the site by querying the newt table for siteId
const [deletedNewt] = await trx
.delete(newts)
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
deletedNewtIds.push(deletedNewt.newtId);
// delete all of the sessions for the newt
await trx
.delete(newtSessions)
.where(
eq(
newtSessions.newtId,
deletedNewt.newtId
)
);
}
// delete all of the sessions for the newt
await trx
.delete(newtSessions)
.where(
eq(newtSessions.newtId, deletedNewt.newtId)
);
}
}
logger.info(`Deleting site ${site.siteId}`);
await trx
.delete(sites)
.where(eq(sites.siteId, site.siteId));
}
logger.info(`Deleting site ${site.siteId}`);
await trx.delete(sites).where(eq(sites.siteId, site.siteId));
}
for (const client of orgClients) {
const [olm] = await trx
.select()
.from(olms)
.where(eq(olms.clientId, client.clientId))
.limit(1);
if (olm) {
olmsToTerminate.push(olm.olmId);
}
logger.info(`Deleting client ${client.clientId}`);
await trx
.delete(clients)
.where(eq(clients.clientId, client.clientId));
// also delete the associations
await trx
.delete(clientSiteResourcesAssociationsCache)
.where(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
);
await trx
.delete(clientSitesAssociationsCache)
.where(
eq(
clientSitesAssociationsCache.clientId,
client.clientId
)
);
}
const allOrgDomains = await trx
@@ -150,7 +193,7 @@ export async function deleteOrg(
// Send termination messages outside of transaction to prevent blocking
for (const newtId of deletedNewtIds) {
const payload = {
type: `newt/terminate`,
type: `newt/wg/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
@@ -162,6 +205,18 @@ export async function deleteOrg(
});
}
for (const olmId of olmsToTerminate) {
sendToClient(olmId, {
type: "olm/terminate",
data: {}
}).catch((error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
});
}
return response(res, {
data: null,
success: true,

View File

@@ -130,7 +130,7 @@ export async function getOrgOverview(
numSites,
numUsers,
numResources,
isAdmin: role.name === "Admin",
isAdmin: role.isAdmin || false,
isOwner: req.userOrg?.isOwner || false
},
success: true,

View File

@@ -0,0 +1,161 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resources } from "@server/db";
import { roleResources, roles } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
const addRoleToResourceBodySchema = z
.object({
roleId: z.number().int().positive()
})
.strict();
const addRoleToResourceParamsSchema = z
.object({
resourceId: z
.string()
.transform(Number)
.pipe(z.number().int().positive())
})
.strict();
registry.registerPath({
method: "post",
path: "/resource/{resourceId}/roles/add",
description: "Add a single role to a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
request: {
params: addRoleToResourceParamsSchema,
body: {
content: {
"application/json": {
schema: addRoleToResourceBodySchema
}
}
}
},
responses: {}
});
export async function addRoleToResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = addRoleToResourceBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { roleId } = parsedBody.data;
const parsedParams = addRoleToResourceParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { resourceId } = parsedParams.data;
// get the resource
const [resource] = await db
.select()
.from(resources)
.where(eq(resources.resourceId, resourceId))
.limit(1);
if (!resource) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Resource not found")
);
}
// verify the role exists and belongs to the same org
const [role] = await db
.select()
.from(roles)
.where(
and(
eq(roles.roleId, roleId),
eq(roles.orgId, resource.orgId)
)
)
.limit(1);
if (!role) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Role not found or does not belong to the same organization"
)
);
}
// Check if the role is an admin role
if (role.isAdmin) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Admin role cannot be assigned to resources"
)
);
}
// Check if role already exists in resource
const existingEntry = await db
.select()
.from(roleResources)
.where(
and(
eq(roleResources.resourceId, resourceId),
eq(roleResources.roleId, roleId)
)
);
if (existingEntry.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Role already assigned to resource"
)
);
}
await db.insert(roleResources).values({
roleId,
resourceId
});
return response(res, {
data: {},
success: true,
error: false,
message: "Role added to resource successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
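A hedged sketch of the request this route expects, assuming an /api/v1 prefix on the registered path. The handler then answers 404 when the role belongs to another org, 400 for admin roles, and 409 when the role is already assigned:

    // Hypothetical caller of POST /resource/{resourceId}/roles/add.
    await fetch("/api/v1/resource/42/roles/add", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ roleId: 7 }) // roleId per the body schema above
    });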

View File

@@ -0,0 +1,130 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resources } from "@server/db";
import { userResources } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
const addUserToResourceBodySchema = z
.object({
userId: z.string()
})
.strict();
const addUserToResourceParamsSchema = z
.object({
resourceId: z
.string()
.transform(Number)
.pipe(z.number().int().positive())
})
.strict();
registry.registerPath({
method: "post",
path: "/resource/{resourceId}/users/add",
description: "Add a single user to a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.User],
request: {
params: addUserToResourceParamsSchema,
body: {
content: {
"application/json": {
schema: addUserToResourceBodySchema
}
}
}
},
responses: {}
});
export async function addUserToResource(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = addUserToResourceBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { userId } = parsedBody.data;
const parsedParams = addUserToResourceParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { resourceId } = parsedParams.data;
// get the resource
const [resource] = await db
.select()
.from(resources)
.where(eq(resources.resourceId, resourceId))
.limit(1);
if (!resource) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Resource not found")
);
}
// Check if user already exists in resource
const existingEntry = await db
.select()
.from(userResources)
.where(
and(
eq(userResources.resourceId, resourceId),
eq(userResources.userId, userId)
)
);
if (existingEntry.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
"User already assigned to resource"
)
);
}
await db.insert(userResources).values({
userId,
resourceId
});
return response(res, {
data: {},
success: true,
error: false,
message: "User added to resource successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -25,4 +25,8 @@ export * from "./getUserResources";
export * from "./setResourceHeaderAuth";
export * from "./addEmailToResourceWhitelist";
export * from "./removeEmailFromResourceWhitelist";
export * from "./addRoleToResource";
export * from "./removeRoleFromResource";
export * from "./addUserToResource";
export * from "./removeUserFromResource";
export * from "./listAllResourceNames";

Some files were not shown because too many files have changed in this diff