mirror of
https://github.com/fosrl/pangolin.git
synced 2026-02-15 09:26:40 +00:00
Merge branch 'dev' into feat/login-page-customization
This commit is contained in:
@@ -86,6 +86,7 @@ export enum ActionsEnum {
|
||||
updateOrgDomain = "updateOrgDomain",
|
||||
getDNSRecords = "getDNSRecords",
|
||||
createNewt = "createNewt",
|
||||
createOlm = "createOlm",
|
||||
createIdp = "createIdp",
|
||||
updateIdp = "updateIdp",
|
||||
deleteIdp = "deleteIdp",
|
||||
|
||||
@@ -36,13 +36,15 @@ export async function createSession(
|
||||
const sessionId = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(token))
|
||||
);
|
||||
const session: Session = {
|
||||
sessionId: sessionId,
|
||||
userId,
|
||||
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
|
||||
issuedAt: new Date().getTime()
|
||||
};
|
||||
await db.insert(sessions).values(session);
|
||||
const [session] = await db
|
||||
.insert(sessions)
|
||||
.values({
|
||||
sessionId: sessionId,
|
||||
userId,
|
||||
expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
|
||||
issuedAt: new Date().getTime()
|
||||
})
|
||||
.returning();
|
||||
return session;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,43 @@
|
||||
import { Request } from "express";
|
||||
import { validateSessionToken, SESSION_COOKIE_NAME } from "@server/auth/sessions/app";
|
||||
import {
|
||||
validateSessionToken,
|
||||
SESSION_COOKIE_NAME
|
||||
} from "@server/auth/sessions/app";
|
||||
|
||||
export async function verifySession(req: Request) {
|
||||
export async function verifySession(req: Request, forceLogin?: boolean) {
|
||||
const res = await validateSessionToken(
|
||||
req.cookies[SESSION_COOKIE_NAME] ?? "",
|
||||
req.cookies[SESSION_COOKIE_NAME] ?? ""
|
||||
);
|
||||
|
||||
if (!forceLogin) {
|
||||
return res;
|
||||
}
|
||||
if (!res.session || !res.user) {
|
||||
return {
|
||||
session: null,
|
||||
user: null
|
||||
};
|
||||
}
|
||||
if (res.session.deviceAuthUsed) {
|
||||
return {
|
||||
session: null,
|
||||
user: null
|
||||
};
|
||||
}
|
||||
if (!res.session.issuedAt) {
|
||||
return {
|
||||
session: null,
|
||||
user: null
|
||||
};
|
||||
}
|
||||
const mins = 5 * 60 * 1000;
|
||||
const now = new Date().getTime();
|
||||
if (now - res.session.issuedAt > mins) {
|
||||
return {
|
||||
session: null,
|
||||
user: null
|
||||
};
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -42,11 +42,17 @@ export async function getUniqueResourceName(orgId: string): Promise<string> {
|
||||
}
|
||||
|
||||
const name = generateName();
|
||||
const count = await db
|
||||
.select({ niceId: resources.niceId, orgId: resources.orgId })
|
||||
.from(resources)
|
||||
.where(and(eq(resources.niceId, name), eq(resources.orgId, orgId)));
|
||||
if (count.length === 0) {
|
||||
const [resourceCount, siteResourceCount] = await Promise.all([
|
||||
db
|
||||
.select({ niceId: resources.niceId, orgId: resources.orgId })
|
||||
.from(resources)
|
||||
.where(and(eq(resources.niceId, name), eq(resources.orgId, orgId))),
|
||||
db
|
||||
.select({ niceId: siteResources.niceId, orgId: siteResources.orgId })
|
||||
.from(siteResources)
|
||||
.where(and(eq(siteResources.niceId, name), eq(siteResources.orgId, orgId)))
|
||||
]);
|
||||
if (resourceCount.length === 0 && siteResourceCount.length === 0) {
|
||||
return name;
|
||||
}
|
||||
loops++;
|
||||
@@ -61,11 +67,17 @@ export async function getUniqueSiteResourceName(orgId: string): Promise<string>
|
||||
}
|
||||
|
||||
const name = generateName();
|
||||
const count = await db
|
||||
.select({ niceId: siteResources.niceId, orgId: siteResources.orgId })
|
||||
.from(siteResources)
|
||||
.where(and(eq(siteResources.niceId, name), eq(siteResources.orgId, orgId)));
|
||||
if (count.length === 0) {
|
||||
const [resourceCount, siteResourceCount] = await Promise.all([
|
||||
db
|
||||
.select({ niceId: resources.niceId, orgId: resources.orgId })
|
||||
.from(resources)
|
||||
.where(and(eq(resources.niceId, name), eq(resources.orgId, orgId))),
|
||||
db
|
||||
.select({ niceId: siteResources.niceId, orgId: siteResources.orgId })
|
||||
.from(siteResources)
|
||||
.where(and(eq(siteResources.niceId, name), eq(siteResources.orgId, orgId)))
|
||||
]);
|
||||
if (resourceCount.length === 0 && siteResourceCount.length === 0) {
|
||||
return name;
|
||||
}
|
||||
loops++;
|
||||
|
||||
@@ -73,7 +73,7 @@ function createDb() {
|
||||
|
||||
return withReplicas(
|
||||
DrizzlePostgres(primaryPool, {
|
||||
logger: process.env.NODE_ENV === "development"
|
||||
logger: process.env.QUERY_LOGGING === "true"
|
||||
}),
|
||||
replicas as any
|
||||
);
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
} from "drizzle-orm/pg-core";
|
||||
import { InferSelectModel } from "drizzle-orm";
|
||||
import { randomUUID } from "crypto";
|
||||
import { alias } from "yargs";
|
||||
|
||||
export const domains = pgTable("domains", {
|
||||
domainId: varchar("domainId").primaryKey(),
|
||||
@@ -41,6 +42,7 @@ export const orgs = pgTable("orgs", {
|
||||
orgId: varchar("orgId").primaryKey(),
|
||||
name: varchar("name").notNull(),
|
||||
subnet: varchar("subnet"),
|
||||
utilitySubnet: varchar("utilitySubnet"), // this is the subnet for utility addresses
|
||||
createdAt: text("createdAt"),
|
||||
requireTwoFactor: boolean("requireTwoFactor"),
|
||||
maxSessionLengthHours: integer("maxSessionLengthHours"),
|
||||
@@ -89,8 +91,7 @@ export const sites = pgTable("sites", {
|
||||
publicKey: varchar("publicKey"),
|
||||
lastHolePunch: bigint("lastHolePunch", { mode: "number" }),
|
||||
listenPort: integer("listenPort"),
|
||||
dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true),
|
||||
remoteSubnets: text("remoteSubnets") // comma-separated list of subnets that this site can access
|
||||
dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true)
|
||||
});
|
||||
|
||||
export const resources = pgTable("resources", {
|
||||
@@ -206,11 +207,41 @@ export const siteResources = pgTable("siteResources", {
|
||||
.references(() => orgs.orgId, { onDelete: "cascade" }),
|
||||
niceId: varchar("niceId").notNull(),
|
||||
name: varchar("name").notNull(),
|
||||
protocol: varchar("protocol").notNull(),
|
||||
proxyPort: integer("proxyPort").notNull(),
|
||||
destinationPort: integer("destinationPort").notNull(),
|
||||
destinationIp: varchar("destinationIp").notNull(),
|
||||
enabled: boolean("enabled").notNull().default(true)
|
||||
mode: varchar("mode").notNull(), // "host" | "cidr" | "port"
|
||||
protocol: varchar("protocol"), // only for port mode
|
||||
proxyPort: integer("proxyPort"), // only for port mode
|
||||
destinationPort: integer("destinationPort"), // only for port mode
|
||||
destination: varchar("destination").notNull(), // ip, cidr, hostname; validate against the mode
|
||||
enabled: boolean("enabled").notNull().default(true),
|
||||
alias: varchar("alias"),
|
||||
aliasAddress: varchar("aliasAddress")
|
||||
});
|
||||
|
||||
export const clientSiteResources = pgTable("clientSiteResources", {
|
||||
clientId: integer("clientId")
|
||||
.notNull()
|
||||
.references(() => clients.clientId, { onDelete: "cascade" }),
|
||||
siteResourceId: integer("siteResourceId")
|
||||
.notNull()
|
||||
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
|
||||
});
|
||||
|
||||
export const roleSiteResources = pgTable("roleSiteResources", {
|
||||
roleId: integer("roleId")
|
||||
.notNull()
|
||||
.references(() => roles.roleId, { onDelete: "cascade" }),
|
||||
siteResourceId: integer("siteResourceId")
|
||||
.notNull()
|
||||
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
|
||||
});
|
||||
|
||||
export const userSiteResources = pgTable("userSiteResources", {
|
||||
userId: varchar("userId")
|
||||
.notNull()
|
||||
.references(() => users.userId, { onDelete: "cascade" }),
|
||||
siteResourceId: integer("siteResourceId")
|
||||
.notNull()
|
||||
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
|
||||
});
|
||||
|
||||
export const users = pgTable("user", {
|
||||
@@ -258,7 +289,8 @@ export const sessions = pgTable("session", {
|
||||
.notNull()
|
||||
.references(() => users.userId, { onDelete: "cascade" }),
|
||||
expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
|
||||
issuedAt: bigint("issuedAt", { mode: "number" })
|
||||
issuedAt: bigint("issuedAt", { mode: "number" }),
|
||||
deviceAuthUsed: boolean("deviceAuthUsed").notNull().default(false)
|
||||
});
|
||||
|
||||
export const newtSessions = pgTable("newtSession", {
|
||||
@@ -600,7 +632,7 @@ export const idpOrg = pgTable("idpOrg", {
|
||||
});
|
||||
|
||||
export const clients = pgTable("clients", {
|
||||
clientId: serial("id").primaryKey(),
|
||||
clientId: serial("clientId").primaryKey(),
|
||||
orgId: varchar("orgId")
|
||||
.references(() => orgs.orgId, {
|
||||
onDelete: "cascade"
|
||||
@@ -609,6 +641,11 @@ export const clients = pgTable("clients", {
|
||||
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
|
||||
onDelete: "set null"
|
||||
}),
|
||||
userId: text("userId").references(() => users.userId, {
|
||||
// optionally tied to a user and in this case delete when the user deletes
|
||||
onDelete: "cascade"
|
||||
}),
|
||||
olmId: text("olmId"), // to lock it to a specific olm optionally
|
||||
name: varchar("name").notNull(),
|
||||
pubKey: varchar("pubKey"),
|
||||
subnet: varchar("subnet").notNull(),
|
||||
@@ -623,23 +660,40 @@ export const clients = pgTable("clients", {
|
||||
maxConnections: integer("maxConnections")
|
||||
});
|
||||
|
||||
export const clientSites = pgTable("clientSites", {
|
||||
clientId: integer("clientId")
|
||||
.notNull()
|
||||
.references(() => clients.clientId, { onDelete: "cascade" }),
|
||||
siteId: integer("siteId")
|
||||
.notNull()
|
||||
.references(() => sites.siteId, { onDelete: "cascade" }),
|
||||
isRelayed: boolean("isRelayed").notNull().default(false),
|
||||
endpoint: varchar("endpoint")
|
||||
});
|
||||
export const clientSitesAssociationsCache = pgTable(
|
||||
"clientSitesAssociationsCache",
|
||||
{
|
||||
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
|
||||
.notNull(),
|
||||
siteId: integer("siteId").notNull(),
|
||||
isRelayed: boolean("isRelayed").notNull().default(false),
|
||||
endpoint: varchar("endpoint"),
|
||||
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
|
||||
}
|
||||
);
|
||||
|
||||
export const clientSiteResourcesAssociationsCache = pgTable(
|
||||
"clientSiteResourcesAssociationsCache",
|
||||
{
|
||||
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
|
||||
.notNull(),
|
||||
siteResourceId: integer("siteResourceId").notNull()
|
||||
}
|
||||
);
|
||||
|
||||
export const olms = pgTable("olms", {
|
||||
olmId: varchar("id").primaryKey(),
|
||||
secretHash: varchar("secretHash").notNull(),
|
||||
dateCreated: varchar("dateCreated").notNull(),
|
||||
version: text("version"),
|
||||
agent: text("agent"),
|
||||
name: varchar("name"),
|
||||
clientId: integer("clientId").references(() => clients.clientId, {
|
||||
// we will switch this depending on the current org it wants to connect to
|
||||
onDelete: "set null"
|
||||
}),
|
||||
userId: text("userId").references(() => users.userId, {
|
||||
// optionally tied to a user and in this case delete when the user deletes
|
||||
onDelete: "cascade"
|
||||
})
|
||||
});
|
||||
@@ -755,6 +809,21 @@ export const requestAuditLog = pgTable(
|
||||
]
|
||||
);
|
||||
|
||||
export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", {
|
||||
codeId: serial("codeId").primaryKey(),
|
||||
code: text("code").notNull().unique(),
|
||||
ip: text("ip"),
|
||||
city: text("city"),
|
||||
deviceName: text("deviceName"),
|
||||
applicationName: text("applicationName").notNull(),
|
||||
expiresAt: bigint("expiresAt", { mode: "number" }).notNull(),
|
||||
createdAt: bigint("createdAt", { mode: "number" }).notNull(),
|
||||
verified: boolean("verified").notNull().default(false),
|
||||
userId: varchar("userId").references(() => users.userId, {
|
||||
onDelete: "cascade"
|
||||
})
|
||||
});
|
||||
|
||||
export type Org = InferSelectModel<typeof orgs>;
|
||||
export type User = InferSelectModel<typeof users>;
|
||||
export type Site = InferSelectModel<typeof sites>;
|
||||
@@ -795,7 +864,7 @@ export type ApiKey = InferSelectModel<typeof apiKeys>;
|
||||
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
|
||||
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
|
||||
export type Client = InferSelectModel<typeof clients>;
|
||||
export type ClientSite = InferSelectModel<typeof clientSites>;
|
||||
export type ClientSite = InferSelectModel<typeof clientSitesAssociationsCache>;
|
||||
export type Olm = InferSelectModel<typeof olms>;
|
||||
export type OlmSession = InferSelectModel<typeof olmSessions>;
|
||||
export type UserClient = InferSelectModel<typeof userClients>;
|
||||
@@ -810,4 +879,5 @@ export type Blueprint = InferSelectModel<typeof blueprints>;
|
||||
export type LicenseKey = InferSelectModel<typeof licenseKey>;
|
||||
export type SecurityKey = InferSelectModel<typeof securityKeys>;
|
||||
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
|
||||
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
|
||||
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import {
|
||||
sqliteTable,
|
||||
integer,
|
||||
text,
|
||||
real,
|
||||
index
|
||||
} from "drizzle-orm/sqlite-core";
|
||||
import { InferSelectModel } from "drizzle-orm";
|
||||
import { domains, orgs, targets, users, exitNodes, sessions } from "./schema";
|
||||
import { metadata } from "@app/app/[orgId]/settings/layout";
|
||||
import {
|
||||
index,
|
||||
integer,
|
||||
real,
|
||||
sqliteTable,
|
||||
text
|
||||
} from "drizzle-orm/sqlite-core";
|
||||
import { domains, exitNodes, orgs, sessions, users } from "./schema";
|
||||
|
||||
export const certificates = sqliteTable("certificates", {
|
||||
certId: integer("certId").primaryKey({ autoIncrement: true }),
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
index,
|
||||
uniqueIndex
|
||||
} from "drizzle-orm/sqlite-core";
|
||||
import { boolean } from "yargs";
|
||||
import { no } from "zod/v4/locales";
|
||||
|
||||
export const domains = sqliteTable("domains", {
|
||||
domainId: text("domainId").primaryKey(),
|
||||
@@ -39,6 +39,7 @@ export const orgs = sqliteTable("orgs", {
|
||||
orgId: text("orgId").primaryKey(),
|
||||
name: text("name").notNull(),
|
||||
subnet: text("subnet"),
|
||||
utilitySubnet: text("utilitySubnet"), // this is the subnet for utility addresses
|
||||
createdAt: text("createdAt"),
|
||||
requireTwoFactor: integer("requireTwoFactor", { mode: "boolean" }),
|
||||
maxSessionLengthHours: integer("maxSessionLengthHours"), // hours
|
||||
@@ -100,8 +101,7 @@ export const sites = sqliteTable("sites", {
|
||||
listenPort: integer("listenPort"),
|
||||
dockerSocketEnabled: integer("dockerSocketEnabled", { mode: "boolean" })
|
||||
.notNull()
|
||||
.default(true),
|
||||
remoteSubnets: text("remoteSubnets") // comma-separated list of subnets that this site can access
|
||||
.default(true)
|
||||
});
|
||||
|
||||
export const resources = sqliteTable("resources", {
|
||||
@@ -202,7 +202,7 @@ export const targetHealthCheck = sqliteTable("targetHealthCheck", {
|
||||
hcMethod: text("hcMethod").default("GET"),
|
||||
hcStatus: integer("hcStatus"), // http code
|
||||
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy"
|
||||
hcTlsServerName: text("hcTlsServerName"),
|
||||
hcTlsServerName: text("hcTlsServerName")
|
||||
});
|
||||
|
||||
export const exitNodes = sqliteTable("exitNodes", {
|
||||
@@ -233,11 +233,41 @@ export const siteResources = sqliteTable("siteResources", {
|
||||
.references(() => orgs.orgId, { onDelete: "cascade" }),
|
||||
niceId: text("niceId").notNull(),
|
||||
name: text("name").notNull(),
|
||||
protocol: text("protocol").notNull(),
|
||||
proxyPort: integer("proxyPort").notNull(),
|
||||
destinationPort: integer("destinationPort").notNull(),
|
||||
destinationIp: text("destinationIp").notNull(),
|
||||
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true)
|
||||
mode: text("mode").notNull(), // "host" | "cidr" | "port"
|
||||
protocol: text("protocol"), // only for port mode
|
||||
proxyPort: integer("proxyPort"), // only for port mode
|
||||
destinationPort: integer("destinationPort"), // only for port mode
|
||||
destination: text("destination").notNull(), // ip, cidr, hostname
|
||||
enabled: integer("enabled", { mode: "boolean" }).notNull().default(true),
|
||||
alias: text("alias"),
|
||||
aliasAddress: text("aliasAddress")
|
||||
});
|
||||
|
||||
export const clientSiteResources = sqliteTable("clientSiteResources", {
|
||||
clientId: integer("clientId")
|
||||
.notNull()
|
||||
.references(() => clients.clientId, { onDelete: "cascade" }),
|
||||
siteResourceId: integer("siteResourceId")
|
||||
.notNull()
|
||||
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
|
||||
});
|
||||
|
||||
export const roleSiteResources = sqliteTable("roleSiteResources", {
|
||||
roleId: integer("roleId")
|
||||
.notNull()
|
||||
.references(() => roles.roleId, { onDelete: "cascade" }),
|
||||
siteResourceId: integer("siteResourceId")
|
||||
.notNull()
|
||||
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
|
||||
});
|
||||
|
||||
export const userSiteResources = sqliteTable("userSiteResources", {
|
||||
userId: text("userId")
|
||||
.notNull()
|
||||
.references(() => users.userId, { onDelete: "cascade" }),
|
||||
siteResourceId: integer("siteResourceId")
|
||||
.notNull()
|
||||
.references(() => siteResources.siteResourceId, { onDelete: "cascade" })
|
||||
});
|
||||
|
||||
export const users = sqliteTable("user", {
|
||||
@@ -313,7 +343,7 @@ export const newts = sqliteTable("newt", {
|
||||
});
|
||||
|
||||
export const clients = sqliteTable("clients", {
|
||||
clientId: integer("id").primaryKey({ autoIncrement: true }),
|
||||
clientId: integer("clientId").primaryKey({ autoIncrement: true }),
|
||||
orgId: text("orgId")
|
||||
.references(() => orgs.orgId, {
|
||||
onDelete: "cascade"
|
||||
@@ -322,8 +352,14 @@ export const clients = sqliteTable("clients", {
|
||||
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
|
||||
onDelete: "set null"
|
||||
}),
|
||||
userId: text("userId").references(() => users.userId, {
|
||||
// optionally tied to a user and in this case delete when the user deletes
|
||||
onDelete: "cascade"
|
||||
}),
|
||||
|
||||
name: text("name").notNull(),
|
||||
pubKey: text("pubKey"),
|
||||
olmId: text("olmId"), // to lock it to a specific olm optionally
|
||||
subnet: text("subnet").notNull(),
|
||||
megabytesIn: integer("bytesIn"),
|
||||
megabytesOut: integer("bytesOut"),
|
||||
@@ -335,25 +371,42 @@ export const clients = sqliteTable("clients", {
|
||||
lastHolePunch: integer("lastHolePunch")
|
||||
});
|
||||
|
||||
export const clientSites = sqliteTable("clientSites", {
|
||||
clientId: integer("clientId")
|
||||
.notNull()
|
||||
.references(() => clients.clientId, { onDelete: "cascade" }),
|
||||
siteId: integer("siteId")
|
||||
.notNull()
|
||||
.references(() => sites.siteId, { onDelete: "cascade" }),
|
||||
isRelayed: integer("isRelayed", { mode: "boolean" })
|
||||
.notNull()
|
||||
.default(false),
|
||||
endpoint: text("endpoint")
|
||||
});
|
||||
export const clientSitesAssociationsCache = sqliteTable(
|
||||
"clientSitesAssociationsCache",
|
||||
{
|
||||
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
|
||||
.notNull(),
|
||||
siteId: integer("siteId").notNull(),
|
||||
isRelayed: integer("isRelayed", { mode: "boolean" })
|
||||
.notNull()
|
||||
.default(false),
|
||||
endpoint: text("endpoint"),
|
||||
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
|
||||
}
|
||||
);
|
||||
|
||||
export const clientSiteResourcesAssociationsCache = sqliteTable(
|
||||
"clientSiteResourcesAssociationsCache",
|
||||
{
|
||||
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
|
||||
.notNull(),
|
||||
siteResourceId: integer("siteResourceId").notNull()
|
||||
}
|
||||
);
|
||||
|
||||
export const olms = sqliteTable("olms", {
|
||||
olmId: text("id").primaryKey(),
|
||||
secretHash: text("secretHash").notNull(),
|
||||
dateCreated: text("dateCreated").notNull(),
|
||||
version: text("version"),
|
||||
agent: text("agent"),
|
||||
name: text("name"),
|
||||
clientId: integer("clientId").references(() => clients.clientId, {
|
||||
// we will switch this depending on the current org it wants to connect to
|
||||
onDelete: "set null"
|
||||
}),
|
||||
userId: text("userId").references(() => users.userId, {
|
||||
// optionally tied to a user and in this case delete when the user deletes
|
||||
onDelete: "cascade"
|
||||
})
|
||||
});
|
||||
@@ -372,7 +425,10 @@ export const sessions = sqliteTable("session", {
|
||||
.notNull()
|
||||
.references(() => users.userId, { onDelete: "cascade" }),
|
||||
expiresAt: integer("expiresAt").notNull(),
|
||||
issuedAt: integer("issuedAt")
|
||||
issuedAt: integer("issuedAt"),
|
||||
deviceAuthUsed: integer("deviceAuthUsed", { mode: "boolean" })
|
||||
.notNull()
|
||||
.default(false)
|
||||
});
|
||||
|
||||
export const newtSessions = sqliteTable("newtSession", {
|
||||
@@ -809,6 +865,21 @@ export const requestAuditLog = sqliteTable(
|
||||
]
|
||||
);
|
||||
|
||||
export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", {
|
||||
codeId: integer("codeId").primaryKey({ autoIncrement: true }),
|
||||
code: text("code").notNull().unique(),
|
||||
ip: text("ip"),
|
||||
city: text("city"),
|
||||
deviceName: text("deviceName"),
|
||||
applicationName: text("applicationName").notNull(),
|
||||
expiresAt: integer("expiresAt").notNull(),
|
||||
createdAt: integer("createdAt").notNull(),
|
||||
verified: integer("verified", { mode: "boolean" }).notNull().default(false),
|
||||
userId: text("userId").references(() => users.userId, {
|
||||
onDelete: "cascade"
|
||||
})
|
||||
});
|
||||
|
||||
export type Org = InferSelectModel<typeof orgs>;
|
||||
export type User = InferSelectModel<typeof users>;
|
||||
export type Site = InferSelectModel<typeof sites>;
|
||||
@@ -847,7 +918,7 @@ export type ResourceRule = InferSelectModel<typeof resourceRules>;
|
||||
export type Domain = InferSelectModel<typeof domains>;
|
||||
export type DnsRecord = InferSelectModel<typeof dnsRecords>;
|
||||
export type Client = InferSelectModel<typeof clients>;
|
||||
export type ClientSite = InferSelectModel<typeof clientSites>;
|
||||
export type ClientSite = InferSelectModel<typeof clientSitesAssociationsCache>;
|
||||
export type RoleClient = InferSelectModel<typeof roleClients>;
|
||||
export type UserClient = InferSelectModel<typeof userClients>;
|
||||
export type SupporterKey = InferSelectModel<typeof supporterKey>;
|
||||
@@ -866,3 +937,4 @@ export type LicenseKey = InferSelectModel<typeof licenseKey>;
|
||||
export type SecurityKey = InferSelectModel<typeof securityKeys>;
|
||||
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
|
||||
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
|
||||
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
ApiKeyOrg,
|
||||
RemoteExitNode,
|
||||
Session,
|
||||
SiteResource,
|
||||
User,
|
||||
UserOrg
|
||||
} from "@server/db";
|
||||
@@ -77,6 +78,8 @@ declare global {
|
||||
userOrgId?: string;
|
||||
userOrgIds?: string[];
|
||||
remoteExitNode?: RemoteExitNode;
|
||||
siteResource?: SiteResource;
|
||||
orgPolicyAllowed?: boolean;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,19 +122,17 @@ export async function applyBlueprint({
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (site) {
|
||||
logger.debug(
|
||||
`Updating client resource ${result.resource.siteResourceId} on site ${site.sites.siteId}`
|
||||
);
|
||||
logger.debug(
|
||||
`Updating client resource ${result.resource.siteResourceId} on site ${site.sites.siteId}`
|
||||
);
|
||||
|
||||
await addClientTargets(
|
||||
site.newt.newtId,
|
||||
result.resource.destinationIp,
|
||||
result.resource.destinationPort,
|
||||
result.resource.protocol,
|
||||
result.resource.proxyPort
|
||||
);
|
||||
}
|
||||
// await addClientTargets(
|
||||
// site.newt.newtId,
|
||||
// result.resource.destination,
|
||||
// result.resource.destinationPort,
|
||||
// result.resource.protocol,
|
||||
// result.resource.proxyPort
|
||||
// );
|
||||
}
|
||||
|
||||
blueprintSucceeded = true;
|
||||
|
||||
@@ -75,8 +75,9 @@ export async function updateClientResources(
|
||||
.set({
|
||||
name: resourceData.name || resourceNiceId,
|
||||
siteId: site.siteId,
|
||||
mode: "port",
|
||||
proxyPort: resourceData["proxy-port"]!,
|
||||
destinationIp: resourceData.hostname,
|
||||
destination: resourceData.hostname,
|
||||
destinationPort: resourceData["internal-port"],
|
||||
protocol: resourceData.protocol
|
||||
})
|
||||
@@ -98,8 +99,9 @@ export async function updateClientResources(
|
||||
siteId: site.siteId,
|
||||
niceId: resourceNiceId,
|
||||
name: resourceData.name || resourceNiceId,
|
||||
mode: "port",
|
||||
proxyPort: resourceData["proxy-port"]!,
|
||||
destinationIp: resourceData.hostname,
|
||||
destination: resourceData.hostname,
|
||||
destinationPort: resourceData["internal-port"],
|
||||
protocol: resourceData.protocol
|
||||
})
|
||||
|
||||
@@ -221,6 +221,7 @@ export async function updateProxyResources(
|
||||
domainId: domain ? domain.domainId : null,
|
||||
enabled: resourceEnabled,
|
||||
sso: resourceData.auth?.["sso-enabled"] || false,
|
||||
skipToIdpId: resourceData.auth?.["auto-login-idp"] || null,
|
||||
ssl: resourceSsl,
|
||||
setHostHeader: resourceData["host-header"] || null,
|
||||
tlsServerName: resourceData["tls-server-name"] || null,
|
||||
@@ -610,6 +611,7 @@ export async function updateProxyResources(
|
||||
domainId: domain ? domain.domainId : null,
|
||||
enabled: resourceEnabled,
|
||||
sso: resourceData.auth?.["sso-enabled"] || false,
|
||||
skipToIdpId: resourceData.auth?.["auto-login-idp"] || null,
|
||||
setHostHeader: resourceData["host-header"] || null,
|
||||
tlsServerName: resourceData["tls-server-name"] || null,
|
||||
ssl: resourceSsl,
|
||||
@@ -789,10 +791,6 @@ async function syncRoleResources(
|
||||
.where(eq(roleResources.resourceId, resourceId));
|
||||
|
||||
for (const roleName of ssoRoles) {
|
||||
if (roleName === "Admin") {
|
||||
continue; // never add admin access
|
||||
}
|
||||
|
||||
const [role] = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
@@ -803,6 +801,10 @@ async function syncRoleResources(
|
||||
throw new Error(`Role not found: ${roleName} in org ${orgId}`);
|
||||
}
|
||||
|
||||
if (role.isAdmin) {
|
||||
continue; // never add admin access
|
||||
}
|
||||
|
||||
const existingRoleResource = existingRoleResources.find(
|
||||
(rr) => rr.roleId === role.roleId
|
||||
);
|
||||
|
||||
@@ -59,6 +59,7 @@ export const AuthSchema = z.object({
|
||||
}),
|
||||
"sso-users": z.array(z.email()).optional().default([]),
|
||||
"whitelist-users": z.array(z.email()).optional().default([]),
|
||||
"auto-login-idp": z.int().positive().optional(),
|
||||
});
|
||||
|
||||
export const RuleSchema = z.object({
|
||||
|
||||
286
server/lib/calculateUserClientsForOrgs.ts
Normal file
286
server/lib/calculateUserClientsForOrgs.ts
Normal file
@@ -0,0 +1,286 @@
|
||||
import {
|
||||
clients,
|
||||
db,
|
||||
olms,
|
||||
orgs,
|
||||
roleClients,
|
||||
roles,
|
||||
userClients,
|
||||
userOrgs,
|
||||
Transaction
|
||||
} from "@server/db";
|
||||
import { eq, and, notInArray } from "drizzle-orm";
|
||||
import { listExitNodes } from "#dynamic/lib/exitNodes";
|
||||
import { getNextAvailableClientSubnet } from "@server/lib/ip";
|
||||
import logger from "@server/logger";
|
||||
import { rebuildClientAssociationsFromClient } from "./rebuildClientAssociations";
|
||||
import { sendTerminateClient } from "@server/routers/client/terminate";
|
||||
|
||||
export async function calculateUserClientsForOrgs(
|
||||
userId: string,
|
||||
trx?: Transaction
|
||||
): Promise<void> {
|
||||
const execute = async (transaction: Transaction) => {
|
||||
// Get all OLMs for this user
|
||||
const userOlms = await transaction
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.userId, userId));
|
||||
|
||||
if (userOlms.length === 0) {
|
||||
// No OLMs for this user, but we should still clean up any orphaned clients
|
||||
await cleanupOrphanedClients(userId, transaction);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get all user orgs
|
||||
const allUserOrgs = await transaction
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(eq(userOrgs.userId, userId));
|
||||
|
||||
const userOrgIds = allUserOrgs.map((uo) => uo.orgId);
|
||||
|
||||
// For each OLM, ensure there's a client in each org the user is in
|
||||
for (const olm of userOlms) {
|
||||
for (const userOrg of allUserOrgs) {
|
||||
const orgId = userOrg.orgId;
|
||||
|
||||
const [org] = await transaction
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
if (!org) {
|
||||
logger.warn(
|
||||
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): org not found`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!org.subnet) {
|
||||
logger.warn(
|
||||
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): org has no subnet configured`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get admin role for this org (needed for access grants)
|
||||
const [adminRole] = await transaction
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (!adminRole) {
|
||||
logger.warn(
|
||||
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no admin role found`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if a client already exists for this OLM+user+org combination
|
||||
const [existingClient] = await transaction
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
eq(clients.userId, userId),
|
||||
eq(clients.orgId, orgId),
|
||||
eq(clients.olmId, olm.olmId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (existingClient) {
|
||||
// Ensure admin role has access to the client
|
||||
const [existingRoleClient] = await transaction
|
||||
.select()
|
||||
.from(roleClients)
|
||||
.where(
|
||||
and(
|
||||
eq(roleClients.roleId, adminRole.roleId),
|
||||
eq(
|
||||
roleClients.clientId,
|
||||
existingClient.clientId
|
||||
)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!existingRoleClient) {
|
||||
await transaction.insert(roleClients).values({
|
||||
roleId: adminRole.roleId,
|
||||
clientId: existingClient.clientId
|
||||
});
|
||||
logger.debug(
|
||||
`Granted admin role access to existing client ${existingClient.clientId} for OLM ${olm.olmId} in org ${orgId} (user ${userId})`
|
||||
);
|
||||
}
|
||||
|
||||
// Ensure user has access to the client
|
||||
const [existingUserClient] = await transaction
|
||||
.select()
|
||||
.from(userClients)
|
||||
.where(
|
||||
and(
|
||||
eq(userClients.userId, userId),
|
||||
eq(
|
||||
userClients.clientId,
|
||||
existingClient.clientId
|
||||
)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!existingUserClient) {
|
||||
await transaction.insert(userClients).values({
|
||||
userId,
|
||||
clientId: existingClient.clientId
|
||||
});
|
||||
logger.debug(
|
||||
`Granted user access to existing client ${existingClient.clientId} for OLM ${olm.olmId} in org ${orgId} (user ${userId})`
|
||||
);
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Client already exists for OLM ${olm.olmId} in org ${orgId} (user ${userId}), skipping creation`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get exit nodes for this org
|
||||
const exitNodesList = await listExitNodes(orgId);
|
||||
|
||||
if (exitNodesList.length === 0) {
|
||||
logger.warn(
|
||||
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no exit nodes found`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const randomExitNode =
|
||||
exitNodesList[
|
||||
Math.floor(Math.random() * exitNodesList.length)
|
||||
];
|
||||
|
||||
// Get next available subnet
|
||||
const newSubnet = await getNextAvailableClientSubnet(orgId);
|
||||
if (!newSubnet) {
|
||||
logger.warn(
|
||||
`Skipping org ${orgId} for OLM ${olm.olmId} (user ${userId}): no available subnet found`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const subnet = newSubnet.split("/")[0];
|
||||
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`;
|
||||
|
||||
// Create the client
|
||||
const [newClient] = await transaction
|
||||
.insert(clients)
|
||||
.values({
|
||||
userId,
|
||||
orgId: userOrg.orgId,
|
||||
exitNodeId: randomExitNode.exitNodeId,
|
||||
name: olm.name || "User Client",
|
||||
subnet: updatedSubnet,
|
||||
olmId: olm.olmId,
|
||||
type: "olm"
|
||||
})
|
||||
.returning();
|
||||
|
||||
await rebuildClientAssociationsFromClient(
|
||||
newClient,
|
||||
transaction
|
||||
);
|
||||
|
||||
// Grant admin role access to the client
|
||||
await transaction.insert(roleClients).values({
|
||||
roleId: adminRole.roleId,
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
|
||||
// Grant user access to the client
|
||||
await transaction.insert(userClients).values({
|
||||
userId,
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`Created client for OLM ${olm.olmId} in org ${orgId} (user ${userId}) with access granted to admin role and user`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up clients in orgs the user is no longer in
|
||||
await cleanupOrphanedClients(userId, transaction, userOrgIds);
|
||||
};
|
||||
|
||||
if (trx) {
|
||||
// Use provided transaction
|
||||
await execute(trx);
|
||||
} else {
|
||||
// Create new transaction
|
||||
await db.transaction(async (transaction) => {
|
||||
await execute(transaction);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async function cleanupOrphanedClients(
|
||||
userId: string,
|
||||
trx: Transaction,
|
||||
userOrgIds: string[] = []
|
||||
): Promise<void> {
|
||||
// Find all OLM clients for this user that should be deleted
|
||||
// If userOrgIds is empty, delete all OLM clients (user has no orgs)
|
||||
// If userOrgIds has values, delete clients in orgs they're not in
|
||||
const clientsToDelete = await trx
|
||||
.select({ clientId: clients.clientId })
|
||||
.from(clients)
|
||||
.where(
|
||||
userOrgIds.length > 0
|
||||
? and(
|
||||
eq(clients.userId, userId),
|
||||
notInArray(clients.orgId, userOrgIds)
|
||||
)
|
||||
: and(eq(clients.userId, userId))
|
||||
);
|
||||
|
||||
if (clientsToDelete.length > 0) {
|
||||
const deletedClients = await trx
|
||||
.delete(clients)
|
||||
.where(
|
||||
userOrgIds.length > 0
|
||||
? and(
|
||||
eq(clients.userId, userId),
|
||||
notInArray(clients.orgId, userOrgIds)
|
||||
)
|
||||
: and(eq(clients.userId, userId))
|
||||
)
|
||||
.returning();
|
||||
|
||||
// Rebuild associations for each deleted client to clean up related data
|
||||
for (const deletedClient of deletedClients) {
|
||||
await rebuildClientAssociationsFromClient(deletedClient, trx);
|
||||
|
||||
if (deletedClient.olmId) {
|
||||
await sendTerminateClient(
|
||||
deletedClient.clientId,
|
||||
deletedClient.olmId
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (userOrgIds.length === 0) {
|
||||
logger.debug(
|
||||
`Deleted all ${clientsToDelete.length} OLM client(s) for user ${userId} (user has no orgs)`
|
||||
);
|
||||
} else {
|
||||
logger.debug(
|
||||
`Deleted ${clientsToDelete.length} orphaned OLM client(s) for user ${userId} in orgs they're no longer in`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -85,10 +85,6 @@ export class Config {
|
||||
? "true"
|
||||
: "false";
|
||||
|
||||
process.env.FLAGS_ENABLE_CLIENTS = parsedConfig.flags?.enable_clients
|
||||
? "true"
|
||||
: "false";
|
||||
|
||||
process.env.PRODUCT_UPDATES_NOTIFICATION_ENABLED = parsedConfig.app
|
||||
.notifications.product_updates
|
||||
? "true"
|
||||
|
||||
@@ -2,7 +2,7 @@ import path from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
// This is a placeholder value replaced by the build process
|
||||
export const APP_VERSION = "1.12.1";
|
||||
export const APP_VERSION = "1.12.3";
|
||||
|
||||
export const __FILENAME = fileURLToPath(import.meta.url);
|
||||
export const __DIRNAME = path.dirname(__FILENAME);
|
||||
|
||||
@@ -18,6 +18,7 @@ import { defaultRoleAllowedActions } from "@server/routers/role";
|
||||
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing";
|
||||
import { createCustomer } from "#dynamic/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export async function createUserAccountOrg(
|
||||
userId: string,
|
||||
@@ -76,6 +77,8 @@ export async function createUserAccountOrg(
|
||||
.from(domains)
|
||||
.where(eq(domains.configManaged, true));
|
||||
|
||||
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
|
||||
|
||||
const newOrg = await trx
|
||||
.insert(orgs)
|
||||
.values({
|
||||
@@ -83,6 +86,7 @@ export async function createUserAccountOrg(
|
||||
name,
|
||||
// subnet
|
||||
subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
|
||||
utilitySubnet: utilitySubnet,
|
||||
createdAt: new Date().toISOString()
|
||||
})
|
||||
.returning();
|
||||
|
||||
170
server/lib/ip.ts
170
server/lib/ip.ts
@@ -1,7 +1,15 @@
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
clientSitesAssociationsCache,
|
||||
db,
|
||||
SiteResource,
|
||||
siteResources,
|
||||
Transaction
|
||||
} from "@server/db";
|
||||
import { clients, orgs, sites } from "@server/db";
|
||||
import { and, eq, isNotNull } from "drizzle-orm";
|
||||
import config from "@server/lib/config";
|
||||
import z from "zod";
|
||||
import logger from "@server/logger";
|
||||
|
||||
interface IPRange {
|
||||
start: bigint;
|
||||
@@ -279,6 +287,56 @@ export async function getNextAvailableClientSubnet(
|
||||
return subnet;
|
||||
}
|
||||
|
||||
export async function getNextAvailableAliasAddress(
|
||||
orgId: string
|
||||
): Promise<string> {
|
||||
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
|
||||
|
||||
if (!org) {
|
||||
throw new Error(`Organization with ID ${orgId} not found`);
|
||||
}
|
||||
|
||||
if (!org.subnet) {
|
||||
throw new Error(`Organization with ID ${orgId} has no subnet defined`);
|
||||
}
|
||||
|
||||
if (!org.utilitySubnet) {
|
||||
throw new Error(
|
||||
`Organization with ID ${orgId} has no utility subnet defined`
|
||||
);
|
||||
}
|
||||
|
||||
const existingAddresses = await db
|
||||
.select({
|
||||
aliasAddress: siteResources.aliasAddress
|
||||
})
|
||||
.from(siteResources)
|
||||
.where(
|
||||
and(
|
||||
isNotNull(siteResources.aliasAddress),
|
||||
eq(siteResources.orgId, orgId)
|
||||
)
|
||||
);
|
||||
|
||||
const addresses = [
|
||||
...existingAddresses.map(
|
||||
(site) => `${site.aliasAddress?.split("/")[0]}/32`
|
||||
),
|
||||
// reserve a /29 for the dns server and other stuff
|
||||
`${org.utilitySubnet.split("/")[0]}/29`
|
||||
].filter((address) => address !== null) as string[];
|
||||
|
||||
let subnet = findNextAvailableCidr(addresses, 32, org.utilitySubnet);
|
||||
if (!subnet) {
|
||||
throw new Error("No available subnets remaining in space");
|
||||
}
|
||||
|
||||
// remove the cidr
|
||||
subnet = subnet.split("/")[0];
|
||||
|
||||
return subnet;
|
||||
}
|
||||
|
||||
export async function getNextAvailableOrgSubnet(): Promise<string> {
|
||||
const existingAddresses = await db
|
||||
.select({
|
||||
@@ -300,3 +358,113 @@ export async function getNextAvailableOrgSubnet(): Promise<string> {
|
||||
|
||||
return subnet;
|
||||
}
|
||||
|
||||
export function generateRemoteSubnets(allSiteResources: SiteResource[]): string[] {
|
||||
const remoteSubnets = allSiteResources
|
||||
.filter((sr) => {
|
||||
if (sr.mode === "cidr") return true;
|
||||
if (sr.mode === "host") {
|
||||
// check if its a valid IP using zod
|
||||
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
|
||||
const parseResult = ipSchema.safeParse(sr.destination);
|
||||
return parseResult.success;
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.map((sr) => {
|
||||
if (sr.mode === "cidr") return sr.destination;
|
||||
if (sr.mode === "host") {
|
||||
return `${sr.destination}/32`;
|
||||
}
|
||||
return ""; // This should never be reached due to filtering, but satisfies TypeScript
|
||||
})
|
||||
.filter((subnet) => subnet !== ""); // Remove empty strings just to be safe
|
||||
// remove duplicates
|
||||
return Array.from(new Set(remoteSubnets));
|
||||
}
|
||||
|
||||
export type Alias = { alias: string | null; aliasAddress: string | null };
|
||||
|
||||
export function generateAliasConfig(allSiteResources: SiteResource[]): Alias[] {
|
||||
let aliasConfigs = allSiteResources
|
||||
.filter((sr) => sr.alias && sr.aliasAddress && sr.mode == "host")
|
||||
.map((sr) => ({
|
||||
alias: sr.alias,
|
||||
aliasAddress: sr.aliasAddress
|
||||
}));
|
||||
return aliasConfigs;
|
||||
}
|
||||
|
||||
export type SubnetProxyTarget = {
|
||||
sourcePrefix: string; // must be a cidr
|
||||
destPrefix: string; // must be a cidr
|
||||
rewriteTo?: string; // must be a cidr
|
||||
portRange?: {
|
||||
min: number;
|
||||
max: number;
|
||||
}[];
|
||||
};
|
||||
|
||||
export function generateSubnetProxyTargets(
|
||||
siteResource: SiteResource,
|
||||
clients: {
|
||||
clientId: number;
|
||||
pubKey: string | null;
|
||||
subnet: string | null;
|
||||
}[]
|
||||
): SubnetProxyTarget[] {
|
||||
const targets: SubnetProxyTarget[] = [];
|
||||
|
||||
if (clients.length === 0) {
|
||||
logger.debug(
|
||||
`No clients have access to site resource ${siteResource.siteResourceId}, skipping target generation.`
|
||||
);
|
||||
return [];
|
||||
}
|
||||
|
||||
for (const clientSite of clients) {
|
||||
if (!clientSite.subnet) {
|
||||
logger.debug(
|
||||
`Client ${clientSite.clientId} has no subnet, skipping for site resource ${siteResource.siteResourceId}.`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const clientPrefix = `${clientSite.subnet.split("/")[0]}/32`;
|
||||
|
||||
if (siteResource.mode == "host") {
|
||||
let destination = siteResource.destination;
|
||||
// check if this is a valid ip
|
||||
const ipSchema = z.union([z.ipv4(), z.ipv6()]);
|
||||
if (ipSchema.safeParse(destination).success) {
|
||||
destination = `${destination}/32`;
|
||||
|
||||
targets.push({
|
||||
sourcePrefix: clientPrefix,
|
||||
destPrefix: destination
|
||||
});
|
||||
}
|
||||
|
||||
if (siteResource.alias && siteResource.aliasAddress) {
|
||||
// also push a match for the alias address
|
||||
targets.push({
|
||||
sourcePrefix: clientPrefix,
|
||||
destPrefix: `${siteResource.aliasAddress}/32`,
|
||||
rewriteTo: destination
|
||||
});
|
||||
}
|
||||
} else if (siteResource.mode == "cidr") {
|
||||
targets.push({
|
||||
sourcePrefix: clientPrefix,
|
||||
destPrefix: siteResource.destination
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// print a nice representation of the targets
|
||||
// logger.debug(
|
||||
// `Generated subnet proxy targets for: ${JSON.stringify(targets, null, 2)}`
|
||||
// );
|
||||
|
||||
return targets;
|
||||
}
|
||||
|
||||
111
server/lib/lock.ts
Normal file
111
server/lib/lock.ts
Normal file
@@ -0,0 +1,111 @@
|
||||
export class LockManager {
|
||||
/**
|
||||
* Acquire a distributed lock using Redis SET with NX and PX options
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param ttlMs - Time to live in milliseconds
|
||||
* @returns Promise<boolean> - true if lock acquired, false otherwise
|
||||
*/
|
||||
async acquireLock(
|
||||
lockKey: string,
|
||||
ttlMs: number = 30000
|
||||
): Promise<boolean> {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Release a lock using Lua script to ensure atomicity
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
*/
|
||||
async releaseLock(lockKey: string): Promise<void> {}
|
||||
|
||||
/**
|
||||
* Force release a lock regardless of owner (use with caution)
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
*/
|
||||
async forceReleaseLock(lockKey: string): Promise<void> {}
|
||||
|
||||
/**
|
||||
* Check if a lock exists and get its info
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @returns Promise<{exists: boolean, ownedByMe: boolean, ttl: number}>
|
||||
*/
|
||||
async getLockInfo(lockKey: string): Promise<{
|
||||
exists: boolean;
|
||||
ownedByMe: boolean;
|
||||
ttl: number;
|
||||
owner?: string;
|
||||
}> {
|
||||
return { exists: true, ownedByMe: true, ttl: 0 };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extend the TTL of an existing lock owned by this worker
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param ttlMs - New TTL in milliseconds
|
||||
* @returns Promise<boolean> - true if extended successfully
|
||||
*/
|
||||
async extendLock(lockKey: string, ttlMs: number): Promise<boolean> {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempt to acquire lock with retries and exponential backoff
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param ttlMs - Time to live in milliseconds
|
||||
* @param maxRetries - Maximum number of retry attempts
|
||||
* @param baseDelayMs - Base delay between retries in milliseconds
|
||||
* @returns Promise<boolean> - true if lock acquired
|
||||
*/
|
||||
async acquireLockWithRetry(
|
||||
lockKey: string,
|
||||
ttlMs: number = 30000,
|
||||
maxRetries: number = 5,
|
||||
baseDelayMs: number = 100
|
||||
): Promise<boolean> {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a function while holding a lock
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param fn - Function to execute while holding the lock
|
||||
* @param ttlMs - Lock TTL in milliseconds
|
||||
* @returns Promise<T> - Result of the executed function
|
||||
*/
|
||||
async withLock<T>(
|
||||
lockKey: string,
|
||||
fn: () => Promise<T>,
|
||||
ttlMs: number = 30000
|
||||
): Promise<T> {
|
||||
const acquired = await this.acquireLock(lockKey, ttlMs);
|
||||
|
||||
if (!acquired) {
|
||||
throw new Error(`Failed to acquire lock: ${lockKey}`);
|
||||
}
|
||||
|
||||
try {
|
||||
return await fn();
|
||||
} finally {
|
||||
await this.releaseLock(lockKey);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up expired locks - Redis handles this automatically, but this method
|
||||
* can be used to get statistics about locks
|
||||
* @returns Promise<{activeLocksCount: number, locksOwnedByMe: number}>
|
||||
*/
|
||||
async getLockStatistics(): Promise<{
|
||||
activeLocksCount: number;
|
||||
locksOwnedByMe: number;
|
||||
}> {
|
||||
return { activeLocksCount: 0, locksOwnedByMe: 0 };
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the Redis connection
|
||||
*/
|
||||
async disconnect(): Promise<void> {}
|
||||
}
|
||||
|
||||
export const lockManager = new LockManager();
|
||||
@@ -229,6 +229,11 @@ export const configSchema = z
|
||||
.default(51820)
|
||||
.transform(stoi)
|
||||
.pipe(portSchema),
|
||||
clients_start_port: portSchema
|
||||
.optional()
|
||||
.default(21820)
|
||||
.transform(stoi)
|
||||
.pipe(portSchema),
|
||||
base_endpoint: z
|
||||
.string()
|
||||
.optional()
|
||||
@@ -249,12 +254,14 @@ export const configSchema = z
|
||||
orgs: z
|
||||
.object({
|
||||
block_size: z.number().positive().gt(0).optional().default(24),
|
||||
subnet_group: z.string().optional().default("100.90.128.0/24")
|
||||
subnet_group: z.string().optional().default("100.90.128.0/24"),
|
||||
utility_subnet_group: z.string().optional().default("100.96.128.0/24") //just hardcode this for now as well
|
||||
})
|
||||
.optional()
|
||||
.default({
|
||||
block_size: 24,
|
||||
subnet_group: "100.90.128.0/24"
|
||||
subnet_group: "100.90.128.0/24",
|
||||
utility_subnet_group: "100.96.128.0/24"
|
||||
}),
|
||||
rate_limits: z
|
||||
.object({
|
||||
@@ -318,8 +325,7 @@ export const configSchema = z
|
||||
enable_integration_api: z.boolean().optional(),
|
||||
disable_local_sites: z.boolean().optional(),
|
||||
disable_basic_wireguard_sites: z.boolean().optional(),
|
||||
disable_config_managed_domains: z.boolean().optional(),
|
||||
enable_clients: z.boolean().optional().default(true)
|
||||
disable_config_managed_domains: z.boolean().optional()
|
||||
})
|
||||
.optional(),
|
||||
dns: z
|
||||
|
||||
1243
server/lib/rebuildClientAssociations.ts
Normal file
1243
server/lib/rebuildClientAssociations.ts
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ import { getHostMeta } from "./hostMeta";
|
||||
import logger from "@server/logger";
|
||||
import { apiKeys, db, roles } from "@server/db";
|
||||
import { sites, users, orgs, resources, clients, idp } from "@server/db";
|
||||
import { eq, count, notInArray } from "drizzle-orm";
|
||||
import { eq, count, notInArray, and } from "drizzle-orm";
|
||||
import { APP_VERSION } from "./consts";
|
||||
import crypto from "crypto";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
@@ -113,7 +113,12 @@ class TelemetryClient {
|
||||
const [customRoles] = await db
|
||||
.select({ count: count() })
|
||||
.from(roles)
|
||||
.where(notInArray(roles.name, ["Admin", "Member"]));
|
||||
.where(
|
||||
and(
|
||||
eq(roles.isAdmin, false),
|
||||
notInArray(roles.name, ["Member"])
|
||||
)
|
||||
);
|
||||
|
||||
const adminUsers = await db
|
||||
.select({ email: users.email })
|
||||
|
||||
@@ -345,9 +345,9 @@ export async function getTraefikConfig(
|
||||
routerMiddlewares.push(rewriteMiddlewareName);
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
|
||||
// );
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to create path rewrite middleware for resource ${resource.resourceId}: ${error}`
|
||||
|
||||
@@ -11,6 +11,7 @@ export * from "./verifyRoleAccess";
|
||||
export * from "./verifyUserAccess";
|
||||
export * from "./verifyAdmin";
|
||||
export * from "./verifySetResourceUsers";
|
||||
export * from "./verifySetResourceClients";
|
||||
export * from "./verifyUserInRole";
|
||||
export * from "./verifyAccessTokenAccess";
|
||||
export * from "./requestTimeout";
|
||||
@@ -24,7 +25,7 @@ export * from "./integration";
|
||||
export * from "./verifyUserHasAction";
|
||||
export * from "./verifyApiKeyAccess";
|
||||
export * from "./verifyDomainAccess";
|
||||
export * from "./verifyClientsEnabled";
|
||||
export * from "./verifyUserIsOrgOwner";
|
||||
export * from "./verifySiteResourceAccess";
|
||||
export * from "./logActionAudit";
|
||||
export * from "./logActionAudit";
|
||||
export * from "./verifyOlmAccess";
|
||||
|
||||
@@ -7,6 +7,7 @@ export * from "./verifyApiKeyTargetAccess";
|
||||
export * from "./verifyApiKeyRoleAccess";
|
||||
export * from "./verifyApiKeyUserAccess";
|
||||
export * from "./verifyApiKeySetResourceUsers";
|
||||
export * from "./verifyApiKeySetResourceClients";
|
||||
export * from "./verifyAccessTokenAccess";
|
||||
export * from "./verifyApiKeyIsRoot";
|
||||
export * from "./verifyApiKeyApiKeyAccess";
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { and, eq, inArray } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
|
||||
export async function verifyApiKeySetResourceClients(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
) {
|
||||
const apiKey = req.apiKey;
|
||||
const singleClientId = req.params.clientId || req.body.clientId || req.query.clientId;
|
||||
const { clientIds } = req.body;
|
||||
const allClientIds = clientIds || (singleClientId ? [parseInt(singleClientId as string)] : []);
|
||||
|
||||
if (!apiKey) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Key not authenticated")
|
||||
);
|
||||
}
|
||||
|
||||
if (apiKey.isRoot) {
|
||||
// Root keys can access any client in any org
|
||||
return next();
|
||||
}
|
||||
|
||||
if (!req.apiKeyOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Key does not have access to this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (allClientIds.length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
try {
|
||||
const orgId = req.apiKeyOrg.orgId;
|
||||
const clientsData = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
inArray(clients.clientId, allClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
)
|
||||
);
|
||||
|
||||
if (clientsData.length !== allClientIds.length) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Key does not have access to one or more specified clients"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return next();
|
||||
} catch (error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Error checking if key has access to the specified clients"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,9 @@ export async function verifyApiKeySetResourceUsers(
|
||||
next: NextFunction
|
||||
) {
|
||||
const apiKey = req.apiKey;
|
||||
const userIds = req.body.userIds;
|
||||
const singleUserId = req.params.userId || req.body.userId || req.query.userId;
|
||||
const { userIds } = req.body;
|
||||
const allUserIds = userIds || (singleUserId ? [singleUserId] : []);
|
||||
|
||||
if (!apiKey) {
|
||||
return next(
|
||||
@@ -33,11 +35,7 @@ export async function verifyApiKeySetResourceUsers(
|
||||
);
|
||||
}
|
||||
|
||||
if (!userIds) {
|
||||
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid user IDs"));
|
||||
}
|
||||
|
||||
if (userIds.length === 0) {
|
||||
if (allUserIds.length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
@@ -48,12 +46,12 @@ export async function verifyApiKeySetResourceUsers(
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
inArray(userOrgs.userId, userIds),
|
||||
inArray(userOrgs.userId, allUserIds),
|
||||
eq(userOrgs.orgId, orgId)
|
||||
)
|
||||
);
|
||||
|
||||
if (userOrgsData.length !== userIds.length) {
|
||||
if (userOrgsData.length !== allUserIds.length) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
|
||||
@@ -13,8 +13,6 @@ export async function verifyApiKeySiteResourceAccess(
|
||||
try {
|
||||
const apiKey = req.apiKey;
|
||||
const siteResourceId = parseInt(req.params.siteResourceId);
|
||||
const siteId = parseInt(req.params.siteId);
|
||||
const orgId = req.params.orgId;
|
||||
|
||||
if (!apiKey) {
|
||||
return next(
|
||||
@@ -22,11 +20,11 @@ export async function verifyApiKeySiteResourceAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (!siteResourceId || !siteId || !orgId) {
|
||||
if (!siteResourceId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Missing required parameters"
|
||||
"Missing siteResourceId parameter"
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -41,9 +39,7 @@ export async function verifyApiKeySiteResourceAccess(
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(and(
|
||||
eq(siteResources.siteResourceId, siteResourceId),
|
||||
eq(siteResources.siteId, siteId),
|
||||
eq(siteResources.orgId, orgId)
|
||||
eq(siteResources.siteResourceId, siteResourceId)
|
||||
))
|
||||
.limit(1);
|
||||
|
||||
@@ -64,11 +60,11 @@ export async function verifyApiKeySiteResourceAccess(
|
||||
.where(
|
||||
and(
|
||||
eq(apiKeyOrg.apiKeyId, apiKey.apiKeyId),
|
||||
eq(apiKeyOrg.orgId, orgId)
|
||||
eq(apiKeyOrg.orgId, siteResource.orgId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
|
||||
if (apiKeyOrgRes.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
@@ -77,12 +73,11 @@ export async function verifyApiKeySiteResourceAccess(
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
req.apiKeyOrg = apiKeyOrgRes[0];
|
||||
}
|
||||
|
||||
// Attach the siteResource to the request for use in the next middleware/route
|
||||
// @ts-ignore - Extending Request type
|
||||
req.siteResource = siteResource;
|
||||
|
||||
return next();
|
||||
|
||||
@@ -5,6 +5,7 @@ import { and, eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { canUserAccessResource } from "@server/auth/canUserAccessResource";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyAccessTokenAccess(
|
||||
req: Request,
|
||||
@@ -96,6 +97,24 @@ export async function verifyAccessTokenAccess(
|
||||
req.userOrgId = resource[0].orgId!;
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const resourceAllowed = await canUserAccessResource({
|
||||
userId,
|
||||
resourceId,
|
||||
|
||||
@@ -4,6 +4,7 @@ import { roles, userOrgs } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyAdmin(
|
||||
req: Request,
|
||||
@@ -43,6 +44,24 @@ export async function verifyAdmin(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userRole = await db
|
||||
.select()
|
||||
.from(roles)
|
||||
|
||||
@@ -4,6 +4,7 @@ import { userOrgs, apiKeys, apiKeyOrg } from "@server/db";
|
||||
import { and, eq, or } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyApiKeyAccess(
|
||||
req: Request,
|
||||
@@ -84,6 +85,24 @@ export async function verifyApiKeyAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userOrgRoleId = req.userOrg.roleId;
|
||||
req.userOrgRoleId = userOrgRoleId;
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import { userOrgs, clients, roleClients, userClients } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyClientAccess(
|
||||
req: Request,
|
||||
@@ -75,6 +76,24 @@ export async function verifyClientAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userOrgRoleId = req.userOrg.roleId;
|
||||
req.userOrgRoleId = userOrgRoleId;
|
||||
req.userOrgId = client.orgId;
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export async function verifyClientsEnabled(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
) {
|
||||
try {
|
||||
if (!config.getRawConfig().flags?.enable_clients) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_IMPLEMENTED,
|
||||
"Clients are not enabled on this server."
|
||||
)
|
||||
);
|
||||
}
|
||||
return next();
|
||||
} catch (error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to check if clients are enabled"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import { userOrgs, apiKeyOrg } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyDomainAccess(
|
||||
req: Request,
|
||||
@@ -78,6 +79,24 @@ export async function verifyDomainAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userOrgRoleId = req.userOrg.roleId;
|
||||
req.userOrgRoleId = userOrgRoleId;
|
||||
|
||||
|
||||
45
server/middlewares/verifyOlmAccess.ts
Normal file
45
server/middlewares/verifyOlmAccess.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { db, olms } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
|
||||
export async function verifyOlmAccess(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
) {
|
||||
try {
|
||||
const userId = req.user!.userId;
|
||||
const olmId = req.params.olmId || req.body.olmId || req.query.olmId;
|
||||
|
||||
if (!userId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
|
||||
);
|
||||
}
|
||||
|
||||
const [existingOlm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(and(eq(olms.olmId, olmId), eq(olms.userId, userId)));
|
||||
|
||||
if (!existingOlm) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this olm"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return next();
|
||||
} catch (error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Error checking if user has access to this user"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -47,22 +47,22 @@ export async function verifyOrgAccess(
|
||||
);
|
||||
}
|
||||
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
|
||||
logger.debug("Org check policy result", { policyCheck });
|
||||
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
if (req.orgPolicyAllowed === undefined) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// User has access, attach the user's role to the request for potential future use
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
resources,
|
||||
userOrgs,
|
||||
userResources,
|
||||
roleResources,
|
||||
} from "@server/db";
|
||||
import { resources, userOrgs, userResources, roleResources } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyResourceAccess(
|
||||
req: Request,
|
||||
@@ -73,6 +69,24 @@ export async function verifyResourceAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userOrgRoleId = req.userOrg.roleId;
|
||||
req.userOrgRoleId = userOrgRoleId;
|
||||
req.userOrgId = resource[0].orgId;
|
||||
|
||||
@@ -5,6 +5,7 @@ import { and, eq, inArray } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import logger from "@server/logger";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyRoleAccess(
|
||||
req: Request,
|
||||
@@ -105,6 +106,33 @@ export async function verifyRoleAccess(
|
||||
req.userOrgRoleId = userOrg[0].roleId;
|
||||
}
|
||||
|
||||
if (!req.userOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return next();
|
||||
} catch (error) {
|
||||
logger.error("Error verifying role access:", error);
|
||||
@@ -116,4 +144,3 @@ export async function verifyRoleAccess(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import { NextFunction, Response } from "express";
|
||||
import ErrorResponse from "@server/types/ErrorResponse";
|
||||
import { db } from "@server/db";
|
||||
import { users } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { verifySession } from "@server/auth/sessions/verifySession";
|
||||
import { unauthorized } from "@server/auth/unauthorizedResponse";
|
||||
|
||||
@@ -13,24 +8,15 @@ export const verifySessionMiddleware = async (
|
||||
res: Response<ErrorResponse>,
|
||||
next: NextFunction
|
||||
) => {
|
||||
const { session, user } = await verifySession(req);
|
||||
const { forceLogin } = req.query;
|
||||
|
||||
const { session, user } = await verifySession(req, forceLogin === "true");
|
||||
if (!session || !user) {
|
||||
return next(unauthorized());
|
||||
}
|
||||
|
||||
const existingUser = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(eq(users.userId, user.userId));
|
||||
|
||||
if (!existingUser || !existingUser[0]) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "User does not exist")
|
||||
);
|
||||
}
|
||||
|
||||
req.user = existingUser[0];
|
||||
req.user = user;
|
||||
req.session = session;
|
||||
|
||||
next();
|
||||
return next();
|
||||
};
|
||||
|
||||
90
server/middlewares/verifySetResourceClients.ts
Normal file
90
server/middlewares/verifySetResourceClients.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { and, eq, inArray } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifySetResourceClients(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
) {
|
||||
const userId = req.user!.userId;
|
||||
const singleClientId =
|
||||
req.params.clientId || req.body.clientId || req.query.clientId;
|
||||
const { clientIds } = req.body;
|
||||
const allClientIds =
|
||||
clientIds ||
|
||||
(singleClientId ? [parseInt(singleClientId as string)] : []);
|
||||
|
||||
if (!userId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
|
||||
);
|
||||
}
|
||||
|
||||
if (!req.userOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (allClientIds.length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
try {
|
||||
const orgId = req.userOrg.orgId;
|
||||
// get all clients for the clientIds
|
||||
const clientsData = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
inArray(clients.clientId, allClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
)
|
||||
);
|
||||
|
||||
if (clientsData.length !== allClientIds.length) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to one or more specified clients"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return next();
|
||||
} catch (error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Error checking if user has access to the specified clients"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import { userOrgs } from "@server/db";
|
||||
import { and, eq, inArray, or } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifySetResourceUsers(
|
||||
req: Request,
|
||||
@@ -28,6 +29,24 @@ export async function verifySetResourceUsers(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!userIds) {
|
||||
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid user IDs"));
|
||||
}
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
sites,
|
||||
userOrgs,
|
||||
userSites,
|
||||
roleSites,
|
||||
roles,
|
||||
} from "@server/db";
|
||||
import { sites, userOrgs, userSites, roleSites, roles } from "@server/db";
|
||||
import { and, eq, or } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import logger from "@server/logger";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifySiteAccess(
|
||||
req: Request,
|
||||
@@ -82,6 +77,24 @@ export async function verifySiteAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userOrgRoleId = req.userOrg.roleId;
|
||||
req.userOrgRoleId = userOrgRoleId;
|
||||
req.userOrgId = site[0].orgId;
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { db, roleSiteResources, userOrgs, userSiteResources } from "@server/db";
|
||||
import { siteResources } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import logger from "@server/logger";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifySiteResourceAccess(
|
||||
req: Request,
|
||||
@@ -12,44 +13,145 @@ export async function verifySiteResourceAccess(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const siteResourceId = parseInt(req.params.siteResourceId);
|
||||
const siteId = parseInt(req.params.siteId);
|
||||
const orgId = req.params.orgId;
|
||||
const userId = req.user!.userId;
|
||||
const siteResourceId =
|
||||
req.params.siteResourceId ||
|
||||
req.body.siteResourceId ||
|
||||
req.query.siteResourceId;
|
||||
|
||||
if (!siteResourceId || !siteId || !orgId) {
|
||||
if (!userId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
|
||||
);
|
||||
}
|
||||
|
||||
if (!siteResourceId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Missing required parameters"
|
||||
"Site resource ID is required"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const siteResourceIdNum = parseInt(siteResourceId as string, 10);
|
||||
if (isNaN(siteResourceIdNum)) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid site resource ID"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if the site resource exists and belongs to the specified site and org
|
||||
const [siteResource] = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(and(
|
||||
eq(siteResources.siteResourceId, siteResourceId),
|
||||
eq(siteResources.siteId, siteId),
|
||||
eq(siteResources.orgId, orgId)
|
||||
))
|
||||
.where(eq(siteResources.siteResourceId, siteResourceIdNum))
|
||||
.limit(1);
|
||||
|
||||
if (!siteResource) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"Site resource not found"
|
||||
`Site resource with ID ${siteResourceIdNum} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!siteResource.orgId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
`Site resource with ID ${siteResourceIdNum} does not have an organization ID`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!req.userOrg) {
|
||||
const userOrgRole = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, userId),
|
||||
eq(userOrgs.orgId, siteResource.orgId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
req.userOrg = userOrgRole[0];
|
||||
}
|
||||
|
||||
if (!req.userOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const userOrgRoleId = req.userOrg.roleId;
|
||||
req.userOrgRoleId = userOrgRoleId;
|
||||
req.userOrgId = siteResource.orgId;
|
||||
|
||||
// Attach the siteResource to the request for use in the next middleware/route
|
||||
// @ts-ignore - Extending Request type
|
||||
req.siteResource = siteResource;
|
||||
|
||||
next();
|
||||
const roleResourceAccess = await db
|
||||
.select()
|
||||
.from(roleSiteResources)
|
||||
.where(
|
||||
and(
|
||||
eq(roleSiteResources.siteResourceId, siteResourceIdNum),
|
||||
eq(roleSiteResources.roleId, userOrgRoleId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (roleResourceAccess.length > 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const userResourceAccess = await db
|
||||
.select()
|
||||
.from(userSiteResources)
|
||||
.where(
|
||||
and(
|
||||
eq(userSiteResources.userId, userId),
|
||||
eq(userSiteResources.siteResourceId, siteResourceIdNum)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (userResourceAccess.length > 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this resource"
|
||||
)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Error verifying site resource access:", error);
|
||||
return next(
|
||||
|
||||
@@ -5,6 +5,7 @@ import { and, eq } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { canUserAccessResource } from "../auth/canUserAccessResource";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyTargetAccess(
|
||||
req: Request,
|
||||
@@ -102,6 +103,26 @@ export async function verifyTargetAccess(
|
||||
req.userOrgId = resource[0].orgId!;
|
||||
}
|
||||
|
||||
const orgId = req.userOrg.orgId;
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const resourceAllowed = await canUserAccessResource({
|
||||
userId,
|
||||
resourceId,
|
||||
|
||||
@@ -15,7 +15,9 @@ export const verifySessionUserMiddleware = async (
|
||||
res: Response<ErrorResponse>,
|
||||
next: NextFunction
|
||||
) => {
|
||||
const { session, user } = await verifySession(req);
|
||||
const { forceLogin } = req.query;
|
||||
|
||||
const { session, user } = await verifySession(req, forceLogin === "true");
|
||||
if (!session || !user) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(`User session not found. IP: ${req.ip}.`);
|
||||
|
||||
@@ -4,6 +4,7 @@ import { userOrgs } from "@server/db";
|
||||
import { and, eq, or } from "drizzle-orm";
|
||||
import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
|
||||
export async function verifyUserAccess(
|
||||
req: Request,
|
||||
@@ -47,6 +48,24 @@ export async function verifyUserAccess(
|
||||
);
|
||||
}
|
||||
|
||||
if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) {
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: req.userOrg.orgId,
|
||||
userId,
|
||||
session: req.session
|
||||
});
|
||||
req.orgPolicyAllowed = policyCheck.allowed;
|
||||
if (!policyCheck.allowed || policyCheck.error) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Failed organization access policy check: " +
|
||||
(policyCheck.error || "Unknown error")
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return next();
|
||||
} catch (error) {
|
||||
return next(
|
||||
|
||||
@@ -45,6 +45,11 @@ export class PrivateConfig {
|
||||
|
||||
this.rawPrivateConfig = parsedPrivateConfig;
|
||||
|
||||
process.env.BRANDING_HIDE_AUTH_LAYOUT_FOOTER =
|
||||
this.rawPrivateConfig.branding?.hide_auth_layout_footer === true
|
||||
? "true"
|
||||
: "false";
|
||||
|
||||
if (this.rawPrivateConfig.branding?.colors) {
|
||||
process.env.BRANDING_COLORS = JSON.stringify(
|
||||
this.rawPrivateConfig.branding?.colors
|
||||
|
||||
@@ -197,7 +197,7 @@ export async function listExitNodes(orgId: string, filterOnline = false, noCloud
|
||||
|
||||
// // set the item in the database if it is offline
|
||||
// if (isActuallyOnline != node.online) {
|
||||
// await db
|
||||
// await trx
|
||||
// .update(exitNodes)
|
||||
// .set({ online: isActuallyOnline })
|
||||
// .where(eq(exitNodes.exitNodeId, node.exitNodeId));
|
||||
|
||||
363
server/private/lib/lock.ts
Normal file
363
server/private/lib/lock.ts
Normal file
@@ -0,0 +1,363 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { config } from "@server/lib/config";
|
||||
import logger from "@server/logger";
|
||||
import { redis } from "#private/lib/redis";
|
||||
|
||||
export class LockManager {
|
||||
/**
|
||||
* Acquire a distributed lock using Redis SET with NX and PX options
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param ttlMs - Time to live in milliseconds
|
||||
* @returns Promise<boolean> - true if lock acquired, false otherwise
|
||||
*/
|
||||
async acquireLock(
|
||||
lockKey: string,
|
||||
ttlMs: number = 30000
|
||||
): Promise<boolean> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return true;
|
||||
}
|
||||
|
||||
const lockValue = `${
|
||||
config.getRawConfig().gerbil.exit_node_name
|
||||
}:${Date.now()}`;
|
||||
const redisKey = `lock:${lockKey}`;
|
||||
|
||||
try {
|
||||
// Use SET with NX (only set if not exists) and PX (expire in milliseconds)
|
||||
// This is atomic and handles both setting and expiration
|
||||
const result = await redis.set(
|
||||
redisKey,
|
||||
lockValue,
|
||||
"PX",
|
||||
ttlMs,
|
||||
"NX"
|
||||
);
|
||||
|
||||
if (result === "OK") {
|
||||
logger.debug(
|
||||
`Lock acquired: ${lockKey} by ${
|
||||
config.getRawConfig().gerbil.exit_node_name
|
||||
}`
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if the existing lock is from this worker (reentrant behavior)
|
||||
const existingValue = await redis.get(redisKey);
|
||||
if (
|
||||
existingValue &&
|
||||
existingValue.startsWith(
|
||||
`${config.getRawConfig().gerbil.exit_node_name}:`
|
||||
)
|
||||
) {
|
||||
// Extend the lock TTL since it's the same worker
|
||||
await redis.pexpire(redisKey, ttlMs);
|
||||
logger.debug(
|
||||
`Lock extended: ${lockKey} by ${
|
||||
config.getRawConfig().gerbil.exit_node_name
|
||||
}`
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to acquire lock ${lockKey}:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Release a lock using Lua script to ensure atomicity
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
*/
|
||||
async releaseLock(lockKey: string): Promise<void> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return;
|
||||
}
|
||||
|
||||
const redisKey = `lock:${lockKey}`;
|
||||
|
||||
// Lua script to ensure we only delete the lock if it belongs to this worker
|
||||
const luaScript = `
|
||||
local key = KEYS[1]
|
||||
local worker_prefix = ARGV[1]
|
||||
local current_value = redis.call('GET', key)
|
||||
|
||||
if current_value and string.find(current_value, worker_prefix, 1, true) == 1 then
|
||||
return redis.call('DEL', key)
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`;
|
||||
|
||||
try {
|
||||
const result = (await redis.eval(
|
||||
luaScript,
|
||||
1,
|
||||
redisKey,
|
||||
`${config.getRawConfig().gerbil.exit_node_name}:`
|
||||
)) as number;
|
||||
|
||||
if (result === 1) {
|
||||
logger.debug(
|
||||
`Lock released: ${lockKey} by ${
|
||||
config.getRawConfig().gerbil.exit_node_name
|
||||
}`
|
||||
);
|
||||
} else {
|
||||
logger.warn(
|
||||
`Lock not released - not owned by worker: ${lockKey} by ${
|
||||
config.getRawConfig().gerbil.exit_node_name
|
||||
}`
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to release lock ${lockKey}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Force release a lock regardless of owner (use with caution)
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
*/
|
||||
async forceReleaseLock(lockKey: string): Promise<void> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return;
|
||||
}
|
||||
|
||||
const redisKey = `lock:${lockKey}`;
|
||||
|
||||
try {
|
||||
const result = await redis.del(redisKey);
|
||||
if (result === 1) {
|
||||
logger.debug(`Lock force released: ${lockKey}`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to force release lock ${lockKey}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a lock exists and get its info
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @returns Promise<{exists: boolean, ownedByMe: boolean, ttl: number}>
|
||||
*/
|
||||
async getLockInfo(lockKey: string): Promise<{
|
||||
exists: boolean;
|
||||
ownedByMe: boolean;
|
||||
ttl: number;
|
||||
owner?: string;
|
||||
}> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return { exists: false, ownedByMe: true, ttl: 0 };
|
||||
}
|
||||
|
||||
const redisKey = `lock:${lockKey}`;
|
||||
|
||||
try {
|
||||
const [value, ttl] = await Promise.all([
|
||||
redis.get(redisKey),
|
||||
redis.pttl(redisKey)
|
||||
]);
|
||||
|
||||
const exists = value !== null;
|
||||
const ownedByMe =
|
||||
exists &&
|
||||
value!.startsWith(`${config.getRawConfig().gerbil.exit_node_name}:`);
|
||||
const owner = exists ? value!.split(":")[0] : undefined;
|
||||
|
||||
return {
|
||||
exists,
|
||||
ownedByMe,
|
||||
ttl: ttl > 0 ? ttl : 0,
|
||||
owner
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`Failed to get lock info ${lockKey}:`, error);
|
||||
return { exists: false, ownedByMe: false, ttl: 0 };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extend the TTL of an existing lock owned by this worker
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param ttlMs - New TTL in milliseconds
|
||||
* @returns Promise<boolean> - true if extended successfully
|
||||
*/
|
||||
async extendLock(lockKey: string, ttlMs: number): Promise<boolean> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return true;
|
||||
}
|
||||
|
||||
const redisKey = `lock:${lockKey}`;
|
||||
|
||||
// Lua script to extend TTL only if lock is owned by this worker
|
||||
const luaScript = `
|
||||
local key = KEYS[1]
|
||||
local worker_prefix = ARGV[1]
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current_value = redis.call('GET', key)
|
||||
|
||||
if current_value and string.find(current_value, worker_prefix, 1, true) == 1 then
|
||||
return redis.call('PEXPIRE', key, ttl)
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`;
|
||||
|
||||
try {
|
||||
const result = (await redis.eval(
|
||||
luaScript,
|
||||
1,
|
||||
redisKey,
|
||||
`${config.getRawConfig().gerbil.exit_node_name}:`,
|
||||
ttlMs.toString()
|
||||
)) as number;
|
||||
|
||||
if (result === 1) {
|
||||
logger.debug(
|
||||
`Lock extended: ${lockKey} by ${
|
||||
config.getRawConfig().gerbil.exit_node_name
|
||||
} for ${ttlMs}ms`
|
||||
);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to extend lock ${lockKey}:`, error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempt to acquire lock with retries and exponential backoff
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param ttlMs - Time to live in milliseconds
|
||||
* @param maxRetries - Maximum number of retry attempts
|
||||
* @param baseDelayMs - Base delay between retries in milliseconds
|
||||
* @returns Promise<boolean> - true if lock acquired
|
||||
*/
|
||||
async acquireLockWithRetry(
|
||||
lockKey: string,
|
||||
ttlMs: number = 30000,
|
||||
maxRetries: number = 5,
|
||||
baseDelayMs: number = 100
|
||||
): Promise<boolean> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (let attempt = 0; attempt <= maxRetries; attempt++) {
|
||||
const acquired = await this.acquireLock(lockKey, ttlMs);
|
||||
|
||||
if (acquired) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (attempt < maxRetries) {
|
||||
// Exponential backoff with jitter
|
||||
const delay =
|
||||
baseDelayMs * Math.pow(2, attempt) + Math.random() * 100;
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
}
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`Failed to acquire lock ${lockKey} after ${maxRetries + 1} attempts`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a function while holding a lock
|
||||
* @param lockKey - Unique identifier for the lock
|
||||
* @param fn - Function to execute while holding the lock
|
||||
* @param ttlMs - Lock TTL in milliseconds
|
||||
* @returns Promise<T> - Result of the executed function
|
||||
*/
|
||||
async withLock<T>(
|
||||
lockKey: string,
|
||||
fn: () => Promise<T>,
|
||||
ttlMs: number = 30000
|
||||
): Promise<T> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return await fn();
|
||||
}
|
||||
|
||||
const acquired = await this.acquireLock(lockKey, ttlMs);
|
||||
|
||||
if (!acquired) {
|
||||
throw new Error(`Failed to acquire lock: ${lockKey}`);
|
||||
}
|
||||
|
||||
try {
|
||||
return await fn();
|
||||
} finally {
|
||||
await this.releaseLock(lockKey);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up expired locks - Redis handles this automatically, but this method
|
||||
* can be used to get statistics about locks
|
||||
* @returns Promise<{activeLocksCount: number, locksOwnedByMe: number}>
|
||||
*/
|
||||
async getLockStatistics(): Promise<{
|
||||
activeLocksCount: number;
|
||||
locksOwnedByMe: number;
|
||||
}> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return { activeLocksCount: 0, locksOwnedByMe: 0 };
|
||||
}
|
||||
|
||||
try {
|
||||
const keys = await redis.keys("lock:*");
|
||||
let locksOwnedByMe = 0;
|
||||
|
||||
if (keys.length > 0) {
|
||||
const values = await redis.mget(...keys);
|
||||
locksOwnedByMe = values.filter(
|
||||
(value) =>
|
||||
value &&
|
||||
value.startsWith(
|
||||
`${config.getRawConfig().gerbil.exit_node_name}:`
|
||||
)
|
||||
).length;
|
||||
}
|
||||
|
||||
return {
|
||||
activeLocksCount: keys.length,
|
||||
locksOwnedByMe
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error("Failed to get lock statistics:", error);
|
||||
return { activeLocksCount: 0, locksOwnedByMe: 0 };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the Redis connection
|
||||
*/
|
||||
async disconnect(): Promise<void> {
|
||||
if (!redis || !redis.status || redis.status !== "ready") {
|
||||
return;
|
||||
}
|
||||
await redis.quit();
|
||||
}
|
||||
}
|
||||
|
||||
export const lockManager = new LockManager();
|
||||
@@ -124,6 +124,7 @@ export const privateConfigSchema = z.object({
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
hide_auth_layout_footer: z.boolean().optional().default(false),
|
||||
login_page: z
|
||||
.object({
|
||||
subtitle_text: z.string().optional(),
|
||||
|
||||
@@ -434,9 +434,9 @@ export async function getTraefikConfig(
|
||||
routerMiddlewares.push(rewriteMiddlewareName);
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Created path rewrite middleware ${rewriteMiddlewareName}: ${resource.pathMatchType}(${resource.path}) -> ${resource.rewritePathType}(${resource.rewritePath})`
|
||||
// );
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to create path rewrite middleware for resource ${resource.resourceId}: ${error}`
|
||||
|
||||
@@ -16,4 +16,5 @@ export * from "./verifyRemoteExitNodeAccess";
|
||||
export * from "./verifyIdpAccess";
|
||||
export * from "./verifyLoginPageAccess";
|
||||
export * from "./logActionAudit";
|
||||
export * from "./verifySubscription";
|
||||
export * from "./verifySubscription";
|
||||
export * from "./verifyValidLicense";
|
||||
|
||||
@@ -31,7 +31,6 @@ import {
|
||||
verifyUserIsServerAdmin,
|
||||
verifySiteAccess,
|
||||
verifyClientAccess,
|
||||
verifyClientsEnabled,
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import {
|
||||
@@ -437,7 +436,6 @@ authenticated.get(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:clientId/regenerate-client-secret",
|
||||
verifyClientsEnabled,
|
||||
verifyClientAccess,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateClientSecret
|
||||
|
||||
@@ -1043,7 +1043,7 @@ hybridRouter.get(
|
||||
);
|
||||
}
|
||||
|
||||
let rules = await db
|
||||
const rules = await db
|
||||
.select()
|
||||
.from(resourceRules)
|
||||
.where(eq(resourceRules.resourceId, resourceId));
|
||||
@@ -1369,7 +1369,7 @@ const updateHolePunchSchema = z.object({
|
||||
port: z.number(),
|
||||
timestamp: z.number(),
|
||||
reachableAt: z.string().optional(),
|
||||
publicKey: z.string().optional()
|
||||
publicKey: z.string() // this is the client public key
|
||||
});
|
||||
hybridRouter.post(
|
||||
"/gerbil/update-hole-punch",
|
||||
@@ -1408,7 +1408,7 @@ hybridRouter.post(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, newtId, ip, port, timestamp, token, reachableAt } =
|
||||
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
|
||||
parsedParams.data;
|
||||
|
||||
const destinations = await updateAndGenerateEndpointDestinations(
|
||||
@@ -1418,6 +1418,7 @@ hybridRouter.post(
|
||||
port,
|
||||
timestamp,
|
||||
token,
|
||||
publicKey,
|
||||
exitNode,
|
||||
true
|
||||
);
|
||||
@@ -1742,7 +1743,12 @@ hybridRouter.post(
|
||||
tls: logEntry.tls
|
||||
}));
|
||||
|
||||
await db.insert(requestAuditLog).values(logEntries);
|
||||
// batch them into inserts of 100 to avoid exceeding parameter limits
|
||||
const batchSize = 100;
|
||||
for (let i = 0; i < logEntries.length; i += batchSize) {
|
||||
const batch = logEntries.slice(i, i + batchSize);
|
||||
await db.insert(requestAuditLog).values(batch);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
|
||||
@@ -13,13 +13,17 @@
|
||||
|
||||
import * as orgIdp from "#private/routers/orgIdp";
|
||||
import * as org from "#private/routers/org";
|
||||
import * as logs from "#private/routers/auditLogs";
|
||||
|
||||
import { Router } from "express";
|
||||
import {
|
||||
verifyApiKey,
|
||||
verifyApiKeyHasAction,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyOrgAccess,
|
||||
} from "@server/middlewares";
|
||||
import {
|
||||
verifyValidSubscription,
|
||||
verifyValidLicense
|
||||
} from "#private/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
|
||||
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
|
||||
@@ -42,4 +46,42 @@ authenticated.delete(
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
|
||||
logActionAudit(ActionsEnum.deleteIdp),
|
||||
orgIdp.deleteOrgIdp,
|
||||
);
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/action",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logs.queryActionAuditLogs
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/action/export",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
logs.exportActionAuditLogs
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/access",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logs.queryAccessAuditLogs
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/access/export",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
logs.exportAccessAuditLogs
|
||||
);
|
||||
|
||||
@@ -164,7 +164,7 @@ export async function createLoginPage(
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.limit(10);
|
||||
.limit(10);
|
||||
}
|
||||
|
||||
// select a random exit node
|
||||
|
||||
@@ -38,6 +38,7 @@ import { rateLimitService } from "#private/lib/rateLimit";
|
||||
import { messageHandlers } from "@server/routers/ws/messageHandlers";
|
||||
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
|
||||
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
|
||||
// Merge public and private message handlers
|
||||
Object.assign(messageHandlers, privateMessageHandlers);
|
||||
@@ -370,6 +371,9 @@ const sendToClientLocal = async (
|
||||
client.send(messageString);
|
||||
}
|
||||
});
|
||||
|
||||
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -478,7 +482,8 @@ const getActiveNodes = async (
|
||||
// Token verification middleware
|
||||
const verifyToken = async (
|
||||
token: string,
|
||||
clientType: ClientType
|
||||
clientType: ClientType,
|
||||
userToken: string
|
||||
): Promise<TokenPayload | null> => {
|
||||
try {
|
||||
if (clientType === "newt") {
|
||||
@@ -506,6 +511,17 @@ const verifyToken = async (
|
||||
if (!existingOlm || !existingOlm[0]) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (olm.userId) { // this is a user device and we need to check the user token
|
||||
const { session: userSession, user } = await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
return null;
|
||||
}
|
||||
if (user.userId !== olm.userId) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return { client: existingOlm[0], session, clientType };
|
||||
} else if (clientType === "remoteExitNode") {
|
||||
const { session, remoteExitNode } =
|
||||
@@ -652,6 +668,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
url.searchParams.get("token") ||
|
||||
request.headers["sec-websocket-protocol"] ||
|
||||
"";
|
||||
const userToken = url.searchParams.get('userToken') || '';
|
||||
let clientType = url.searchParams.get(
|
||||
"clientType"
|
||||
) as ClientType;
|
||||
@@ -673,7 +690,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
return;
|
||||
}
|
||||
|
||||
const tokenPayload = await verifyToken(token, clientType);
|
||||
const tokenPayload = await verifyToken(token, clientType, userToken);
|
||||
if (!tokenPayload) {
|
||||
logger.debug(
|
||||
"Unauthorized connection attempt: invalid token..."
|
||||
@@ -792,6 +809,28 @@ if (redisManager.isRedisEnabled()) {
|
||||
);
|
||||
}
|
||||
|
||||
// Disconnect a specific client and force them to reconnect
|
||||
const disconnectClient = async (clientId: string): Promise<boolean> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
|
||||
if (!clients || clients.length === 0) {
|
||||
logger.debug(`No connections found for client ID: ${clientId}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
|
||||
|
||||
// Close all connections for this client
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.close(1000, "Disconnected by server");
|
||||
}
|
||||
});
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
// Cleanup function for graceful shutdown
|
||||
const cleanup = async (): Promise<void> => {
|
||||
try {
|
||||
@@ -829,6 +868,7 @@ export {
|
||||
connectedClients,
|
||||
hasActiveConnections,
|
||||
getActiveNodes,
|
||||
disconnectClient,
|
||||
NODE_ID,
|
||||
cleanup
|
||||
};
|
||||
|
||||
@@ -131,7 +131,7 @@ export function queryRequest(data: Q) {
|
||||
eq(requestAuditLog.resourceId, resources.resourceId)
|
||||
) // TODO: Is this efficient?
|
||||
.where(getWhere(data))
|
||||
.orderBy(desc(requestAuditLog.timestamp), desc(requestAuditLog.id));
|
||||
.orderBy(desc(requestAuditLog.timestamp));
|
||||
}
|
||||
|
||||
export function countRequestQuery(data: Q) {
|
||||
|
||||
@@ -13,4 +13,7 @@ export * from "./initialSetupComplete";
|
||||
export * from "./validateSetupToken";
|
||||
export * from "./changePassword";
|
||||
export * from "./checkResourceSession";
|
||||
export * from "./securityKey";
|
||||
export * from "./securityKey";
|
||||
export * from "./startDeviceWebAuth";
|
||||
export * from "./verifyDeviceWebAuth";
|
||||
export * from "./pollDeviceWebAuth";
|
||||
@@ -1,7 +1,9 @@
|
||||
import {
|
||||
createSession,
|
||||
generateSessionToken,
|
||||
serializeSessionCookie
|
||||
invalidateSession,
|
||||
serializeSessionCookie,
|
||||
SESSION_COOKIE_NAME
|
||||
} from "@server/auth/sessions/app";
|
||||
import { db, resources } from "@server/db";
|
||||
import { users, securityKeys } from "@server/db";
|
||||
@@ -21,11 +23,11 @@ import { UserType } from "@server/types/UserTypes";
|
||||
import { logAccessAudit } from "#dynamic/lib/logAccessAudit";
|
||||
|
||||
export const loginBodySchema = z.strictObject({
|
||||
email: z.email().toLowerCase(),
|
||||
password: z.string(),
|
||||
code: z.string().optional(),
|
||||
resourceGuid: z.string().optional()
|
||||
});
|
||||
email: z.email().toLowerCase(),
|
||||
password: z.string(),
|
||||
code: z.string().optional(),
|
||||
resourceGuid: z.string().optional()
|
||||
});
|
||||
|
||||
export type LoginBody = z.infer<typeof loginBodySchema>;
|
||||
|
||||
@@ -41,6 +43,21 @@ export async function login(
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const { forceLogin } = req.query;
|
||||
const { session: existingSession } = await verifySession(
|
||||
req,
|
||||
forceLogin === "true"
|
||||
);
|
||||
if (existingSession) {
|
||||
return response<null>(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Already logged in",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
const parsedBody = loginBodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
@@ -55,17 +72,6 @@ export async function login(
|
||||
const { email, password, code, resourceGuid } = parsedBody.data;
|
||||
|
||||
try {
|
||||
const { session: existingSession } = await verifySession(req);
|
||||
if (existingSession) {
|
||||
return response<null>(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Already logged in",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
let resourceId: number | null = null;
|
||||
let orgId: string | null = null;
|
||||
if (resourceGuid) {
|
||||
@@ -225,6 +231,12 @@ export async function login(
|
||||
}
|
||||
}
|
||||
|
||||
// check for previous cookie value and expire it
|
||||
const previousCookie = req.cookies[SESSION_COOKIE_NAME];
|
||||
if (previousCookie) {
|
||||
await invalidateSession(previousCookie);
|
||||
}
|
||||
|
||||
const token = generateSessionToken();
|
||||
const sess = await createSession(token, existingUser.userId);
|
||||
const isSecure = req.protocol === "https";
|
||||
|
||||
168
server/routers/auth/pollDeviceWebAuth.ts
Normal file
168
server/routers/auth/pollDeviceWebAuth.ts
Normal file
@@ -0,0 +1,168 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import logger from "@server/logger";
|
||||
import { response } from "@server/lib/response";
|
||||
import { db, deviceWebAuthCodes } from "@server/db";
|
||||
import { eq, and, gt } from "drizzle-orm";
|
||||
import {
|
||||
createSession,
|
||||
generateSessionToken
|
||||
} from "@server/auth/sessions/app";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
|
||||
const paramsSchema = z.object({
|
||||
code: z.string().min(1, "Code is required")
|
||||
});
|
||||
|
||||
export type PollDeviceWebAuthParams = z.infer<typeof paramsSchema>;
|
||||
|
||||
// Helper function to hash device code before querying database
|
||||
function hashDeviceCode(code: string): string {
|
||||
return encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(code))
|
||||
);
|
||||
}
|
||||
|
||||
export type PollDeviceWebAuthResponse = {
|
||||
verified: boolean;
|
||||
token?: string;
|
||||
};
|
||||
|
||||
// Helper function to extract IP from request (same as in startDeviceWebAuth)
|
||||
function extractIpFromRequest(req: Request): string | undefined {
|
||||
const ip = req.ip || req.socket.remoteAddress;
|
||||
if (!ip) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Handle IPv6 format [::1] or IPv4 format
|
||||
if (ip.startsWith("[") && ip.includes("]")) {
|
||||
const ipv6Match = ip.match(/\[(.*?)\]/);
|
||||
if (ipv6Match) {
|
||||
return ipv6Match[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle IPv4 with port (split at last colon)
|
||||
const lastColonIndex = ip.lastIndexOf(":");
|
||||
if (lastColonIndex !== -1) {
|
||||
return ip.substring(0, lastColonIndex);
|
||||
}
|
||||
|
||||
return ip;
|
||||
}
|
||||
|
||||
export async function pollDeviceWebAuth(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const { code } = parsedParams.data;
|
||||
const now = Date.now();
|
||||
const requestIp = extractIpFromRequest(req);
|
||||
|
||||
// Hash the code before querying
|
||||
const hashedCode = hashDeviceCode(code);
|
||||
|
||||
// Find the code in the database
|
||||
const [deviceCode] = await db
|
||||
.select()
|
||||
.from(deviceWebAuthCodes)
|
||||
.where(eq(deviceWebAuthCodes.code, hashedCode))
|
||||
.limit(1);
|
||||
|
||||
if (!deviceCode) {
|
||||
return response<PollDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
verified: false
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Code not found",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
// Check if code is expired
|
||||
if (deviceCode.expiresAt <= now) {
|
||||
return response<PollDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
verified: false
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Code expired",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
// Check if code is verified
|
||||
if (!deviceCode.verified) {
|
||||
return response<PollDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
verified: false
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Code not yet verified",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
// Check if userId is set (should be set when verified)
|
||||
if (!deviceCode.userId) {
|
||||
logger.error("Device code is verified but userId is missing", { codeId: deviceCode.codeId });
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Invalid code state"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
const token = generateSessionToken();
|
||||
await createSession(token, deviceCode.userId);
|
||||
|
||||
// Delete the code after successful exchange for a token
|
||||
await db
|
||||
.delete(deviceWebAuthCodes)
|
||||
.where(eq(deviceWebAuthCodes.codeId, deviceCode.codeId));
|
||||
|
||||
return response<PollDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
verified: true,
|
||||
token
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Code verified and session created",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to poll device code"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ setInterval(async () => {
|
||||
await db
|
||||
.delete(webauthnChallenge)
|
||||
.where(lt(webauthnChallenge.expiresAt, now));
|
||||
logger.debug("Cleaned up expired security key challenges");
|
||||
// logger.debug("Cleaned up expired security key challenges");
|
||||
} catch (error) {
|
||||
logger.error("Failed to clean up expired security key challenges", error);
|
||||
}
|
||||
|
||||
156
server/routers/auth/startDeviceWebAuth.ts
Normal file
156
server/routers/auth/startDeviceWebAuth.ts
Normal file
@@ -0,0 +1,156 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import logger from "@server/logger";
|
||||
import { response } from "@server/lib/response";
|
||||
import { db, deviceWebAuthCodes } from "@server/db";
|
||||
import { alphabet, generateRandomString } from "oslo/crypto";
|
||||
import { createDate } from "oslo";
|
||||
import { TimeSpan } from "oslo";
|
||||
import { maxmindLookup } from "@server/db/maxmind";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
|
||||
const bodySchema = z.object({
|
||||
deviceName: z.string().optional(),
|
||||
applicationName: z.string().min(1, "Application name is required")
|
||||
}).strict();
|
||||
|
||||
export type StartDeviceWebAuthBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type StartDeviceWebAuthResponse = {
|
||||
code: string;
|
||||
expiresInSeconds: number;
|
||||
};
|
||||
|
||||
// Helper function to generate device code in format A1AJ-N5JD
|
||||
function generateDeviceCode(): string {
|
||||
const part1 = generateRandomString(4, alphabet("A-Z", "0-9"));
|
||||
const part2 = generateRandomString(4, alphabet("A-Z", "0-9"));
|
||||
return `${part1}-${part2}`;
|
||||
}
|
||||
|
||||
// Helper function to hash device code before storing in database
|
||||
function hashDeviceCode(code: string): string {
|
||||
return encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(code))
|
||||
);
|
||||
}
|
||||
|
||||
// Helper function to extract IP from request
|
||||
function extractIpFromRequest(req: Request): string | undefined {
|
||||
const ip = req.ip || req.socket.remoteAddress;
|
||||
if (!ip) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Handle IPv6 format [::1] or IPv4 format
|
||||
if (ip.startsWith("[") && ip.includes("]")) {
|
||||
const ipv6Match = ip.match(/\[(.*?)\]/);
|
||||
if (ipv6Match) {
|
||||
return ipv6Match[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle IPv4 with port (split at last colon)
|
||||
const lastColonIndex = ip.lastIndexOf(":");
|
||||
if (lastColonIndex !== -1) {
|
||||
return ip.substring(0, lastColonIndex);
|
||||
}
|
||||
|
||||
return ip;
|
||||
}
|
||||
|
||||
// Helper function to get city from IP (if available)
|
||||
async function getCityFromIp(ip: string): Promise<string | undefined> {
|
||||
try {
|
||||
if (!maxmindLookup) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const result = maxmindLookup.get(ip);
|
||||
if (!result) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// MaxMind CountryResponse doesn't include city by default
|
||||
// If city data is available, it would be in result.city?.names?.en
|
||||
// But since we're using CountryResponse type, we'll just return undefined
|
||||
// The user said "don't do this if not easy", so we'll skip city for now
|
||||
return undefined;
|
||||
} catch (error) {
|
||||
logger.debug("Failed to get city from IP", error);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export async function startDeviceWebAuth(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const { deviceName, applicationName } = parsedBody.data;
|
||||
|
||||
// Generate device code
|
||||
const code = generateDeviceCode();
|
||||
|
||||
// Hash the code before storing in database
|
||||
const hashedCode = hashDeviceCode(code);
|
||||
|
||||
// Extract IP from request
|
||||
const ip = extractIpFromRequest(req);
|
||||
|
||||
// Get city (optional, may return undefined)
|
||||
const city = ip ? await getCityFromIp(ip) : undefined;
|
||||
|
||||
// Set expiration to 5 minutes from now
|
||||
const expiresAt = createDate(new TimeSpan(5, "m")).getTime();
|
||||
|
||||
// Insert into database (store hashed code)
|
||||
await db.insert(deviceWebAuthCodes).values({
|
||||
code: hashedCode,
|
||||
ip: ip || null,
|
||||
city: city || null,
|
||||
deviceName: deviceName || null,
|
||||
applicationName,
|
||||
expiresAt,
|
||||
createdAt: Date.now()
|
||||
});
|
||||
|
||||
// calculate relative expiration in seconds
|
||||
const expiresInSeconds = Math.floor((expiresAt - Date.now()) / 1000);
|
||||
|
||||
return response<StartDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
code,
|
||||
expiresInSeconds
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Device web auth code generated",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to start device web auth"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
180
server/routers/auth/verifyDeviceWebAuth.ts
Normal file
180
server/routers/auth/verifyDeviceWebAuth.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import logger from "@server/logger";
|
||||
import { response } from "@server/lib/response";
|
||||
import { db, deviceWebAuthCodes, sessions } from "@server/db";
|
||||
import { eq, and, gt } from "drizzle-orm";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { unauthorized } from "@server/auth/unauthorizedResponse";
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
code: z.string().min(1, "Code is required"),
|
||||
verify: z.boolean().optional().default(false) // If false, just check and return metadata
|
||||
})
|
||||
.strict();
|
||||
|
||||
// Helper function to hash device code before querying database
|
||||
function hashDeviceCode(code: string): string {
|
||||
return encodeHexLowerCase(sha256(new TextEncoder().encode(code)));
|
||||
}
|
||||
|
||||
export type VerifyDeviceWebAuthBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type VerifyDeviceWebAuthResponse = {
|
||||
success: boolean;
|
||||
message: string;
|
||||
metadata?: {
|
||||
ip: string | null;
|
||||
city: string | null;
|
||||
deviceName: string | null;
|
||||
applicationName: string;
|
||||
createdAt: number;
|
||||
};
|
||||
};
|
||||
|
||||
export async function verifyDeviceWebAuth(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const { user, session } = req;
|
||||
if (!user || !session) {
|
||||
return next(createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized"));
|
||||
}
|
||||
|
||||
if (session.deviceAuthUsed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Device web auth code already used for this session"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!session.issuedAt) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Session issuedAt timestamp missing"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// make sure sessions is not older than 5 minutes
|
||||
const now = Date.now();
|
||||
if (now - session.issuedAt > 5 * 60 * 1000) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Session is too old to verify device web auth code"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const { code, verify } = parsedBody.data;
|
||||
const now = Date.now();
|
||||
|
||||
logger.debug("Verifying device web auth code:", { code });
|
||||
|
||||
// Hash the code before querying
|
||||
const hashedCode = hashDeviceCode(code);
|
||||
|
||||
// Find the code in the database that is not expired and not already verified
|
||||
const [deviceCode] = await db
|
||||
.select()
|
||||
.from(deviceWebAuthCodes)
|
||||
.where(
|
||||
and(
|
||||
eq(deviceWebAuthCodes.code, hashedCode),
|
||||
gt(deviceWebAuthCodes.expiresAt, now),
|
||||
eq(deviceWebAuthCodes.verified, false)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
logger.debug("Device code lookup result:", deviceCode);
|
||||
|
||||
if (!deviceCode) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid, expired, or already verified code"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// If verify is false, just return metadata without verifying
|
||||
if (!verify) {
|
||||
return response<VerifyDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
success: true,
|
||||
message: "Code is valid",
|
||||
metadata: {
|
||||
ip: deviceCode.ip,
|
||||
city: deviceCode.city,
|
||||
deviceName: deviceCode.deviceName,
|
||||
applicationName: deviceCode.applicationName,
|
||||
createdAt: deviceCode.createdAt
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Code validation successful",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
// Update the code to mark it as verified and store the user who verified it
|
||||
await db
|
||||
.update(deviceWebAuthCodes)
|
||||
.set({
|
||||
verified: true,
|
||||
userId: req.user!.userId
|
||||
})
|
||||
.where(eq(deviceWebAuthCodes.codeId, deviceCode.codeId));
|
||||
|
||||
// Also update the session to mark that device auth was used
|
||||
await db
|
||||
.update(sessions)
|
||||
.set({
|
||||
deviceAuthUsed: true
|
||||
})
|
||||
.where(eq(sessions.sessionId, session.sessionId));
|
||||
|
||||
return response<VerifyDeviceWebAuthResponse>(res, {
|
||||
data: {
|
||||
success: true,
|
||||
message: "Device code verified successfully"
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Device code verified successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to verify device code"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -69,9 +69,9 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
|
||||
)
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`Cleaned up request audit logs older than ${retentionDays} days`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Cleaned up request audit logs older than ${retentionDays} days`
|
||||
// );
|
||||
} catch (error) {
|
||||
logger.error("Error cleaning up old request audit logs:", error);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import {
|
||||
roleClients,
|
||||
userClients,
|
||||
olms,
|
||||
clientSites,
|
||||
exitNodes,
|
||||
orgs,
|
||||
sites
|
||||
} from "@server/db";
|
||||
@@ -21,23 +19,24 @@ import { eq, and } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import moment from "moment";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { isValidCIDR, isValidIP } from "@server/lib/validators";
|
||||
import { isValidIP } from "@server/lib/validators";
|
||||
import { isIpInCidr } from "@server/lib/ip";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { listExitNodes } from "#dynamic/lib/exitNodes";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
|
||||
const createClientParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const createClientSchema = z.strictObject({
|
||||
name: z.string().min(1).max(255),
|
||||
siteIds: z.array(z.int().positive()),
|
||||
olmId: z.string(),
|
||||
secret: z.string(),
|
||||
subnet: z.string(),
|
||||
type: z.enum(["olm"])
|
||||
});
|
||||
name: z.string().min(1).max(255),
|
||||
olmId: z.string(),
|
||||
secret: z.string(),
|
||||
subnet: z.string(),
|
||||
type: z.enum(["olm"])
|
||||
});
|
||||
|
||||
export type CreateClientBody = z.infer<typeof createClientSchema>;
|
||||
|
||||
@@ -46,7 +45,7 @@ export type CreateClientResponse = Client;
|
||||
registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/client",
|
||||
description: "Create a new client.",
|
||||
description: "Create a new client for an organization.",
|
||||
tags: [OpenAPITags.Client, OpenAPITags.Org],
|
||||
request: {
|
||||
params: createClientParamsSchema,
|
||||
@@ -77,7 +76,7 @@ export async function createClient(
|
||||
);
|
||||
}
|
||||
|
||||
const { name, type, siteIds, olmId, secret, subnet } = parsedBody.data;
|
||||
const { name, type, olmId, secret, subnet } = parsedBody.data;
|
||||
|
||||
const parsedParams = createClientParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -172,75 +171,90 @@ export async function createClient(
|
||||
);
|
||||
}
|
||||
|
||||
// check if the olmId already exists
|
||||
const [existingOlm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.olmId, olmId))
|
||||
.limit(1);
|
||||
|
||||
if (existingOlm) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
`OLM with ID ${olmId} already exists`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let newClient: Client | null = null;
|
||||
await db.transaction(async (trx) => {
|
||||
// TODO: more intelligent way to pick the exit node
|
||||
const exitNodesList = await listExitNodes(orgId);
|
||||
const randomExitNode =
|
||||
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
|
||||
|
||||
const adminRole = await trx
|
||||
const [adminRole] = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (adminRole.length === 0) {
|
||||
trx.rollback();
|
||||
if (!adminRole) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
|
||||
);
|
||||
}
|
||||
|
||||
const [newClient] = await trx
|
||||
[newClient] = await trx
|
||||
.insert(clients)
|
||||
.values({
|
||||
exitNodeId: randomExitNode.exitNodeId,
|
||||
orgId,
|
||||
name,
|
||||
subnet: updatedSubnet,
|
||||
type
|
||||
type,
|
||||
olmId // this is to lock it to a specific olm even if the olm moves across clients
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx.insert(roleClients).values({
|
||||
roleId: adminRole[0].roleId,
|
||||
roleId: adminRole.roleId,
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
|
||||
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
|
||||
// make sure the user can access the site
|
||||
if (req.user && req.userOrgRoleId != adminRole.roleId) {
|
||||
// make sure the user can access the client
|
||||
trx.insert(userClients).values({
|
||||
userId: req.user?.userId!,
|
||||
userId: req.user.userId,
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
}
|
||||
|
||||
// Create site to client associations
|
||||
if (siteIds && siteIds.length > 0) {
|
||||
await trx.insert(clientSites).values(
|
||||
siteIds.map((siteId) => ({
|
||||
clientId: newClient.clientId,
|
||||
siteId
|
||||
}))
|
||||
);
|
||||
let secretToUse = secret;
|
||||
if (!secretToUse) {
|
||||
secretToUse = generateId(48);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
const secretHash = await hashPassword(secretToUse);
|
||||
|
||||
await trx.insert(olms).values({
|
||||
olmId,
|
||||
secretHash,
|
||||
name,
|
||||
clientId: newClient.clientId,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
return response<CreateClientResponse>(res, {
|
||||
data: newClient,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Site created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
await rebuildClientAssociationsFromClient(newClient, trx);
|
||||
});
|
||||
|
||||
return response<CreateClientResponse>(res, {
|
||||
data: newClient,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Site created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
|
||||
253
server/routers/client/createUserClient.ts
Normal file
253
server/routers/client/createUserClient.ts
Normal file
@@ -0,0 +1,253 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
roles,
|
||||
Client,
|
||||
clients,
|
||||
roleClients,
|
||||
userClients,
|
||||
olms,
|
||||
orgs,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { isValidIP } from "@server/lib/validators";
|
||||
import { isIpInCidr } from "@server/lib/ip";
|
||||
import { listExitNodes } from "#dynamic/lib/exitNodes";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
orgId: z.string(),
|
||||
userId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
name: z.string().min(1).max(255),
|
||||
olmId: z.string(),
|
||||
subnet: z.string(),
|
||||
type: z.enum(["olm"])
|
||||
})
|
||||
.strict();
|
||||
|
||||
export type CreateClientAndOlmBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type CreateClientAndOlmResponse = Client;
|
||||
|
||||
registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/user/{userId}/client",
|
||||
description:
|
||||
"Create a new client for a user and associate it with an existing olm.",
|
||||
tags: [OpenAPITags.Client, OpenAPITags.Org, OpenAPITags.User],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: bodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function createUserClient(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { name, type, olmId, subnet } = parsedBody.data;
|
||||
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId, userId } = parsedParams.data;
|
||||
|
||||
if (!isValidIP(subnet)) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid subnet format. Please provide a valid CIDR notation."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
|
||||
|
||||
if (!org) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Organization with ID ${orgId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!org.subnet) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Organization with ID ${orgId} has no subnet defined`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!isIpInCidr(subnet, org.subnet)) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"IP is not in the CIDR range of the subnet."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
|
||||
|
||||
// make sure the subnet is unique
|
||||
const subnetExistsClients = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(eq(clients.subnet, updatedSubnet), eq(clients.orgId, orgId))
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (subnetExistsClients.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
`Subnet ${updatedSubnet} already exists in clients`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const subnetExistsSites = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(
|
||||
and(eq(sites.address, updatedSubnet), eq(sites.orgId, orgId))
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (subnetExistsSites.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
`Subnet ${updatedSubnet} already exists in sites`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// check if the olmId already exists
|
||||
const [existingOlm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.olmId, olmId))
|
||||
.limit(1);
|
||||
|
||||
if (!existingOlm) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`OLM with ID ${olmId} does not exist`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existingOlm.userId !== userId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`OLM with ID ${olmId} does not belong to user with ID ${userId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let newClient: Client | null = null;
|
||||
await db.transaction(async (trx) => {
|
||||
// TODO: more intelligent way to pick the exit node
|
||||
const exitNodesList = await listExitNodes(orgId);
|
||||
const randomExitNode =
|
||||
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
|
||||
|
||||
const [adminRole] = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (!adminRole) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
|
||||
);
|
||||
}
|
||||
|
||||
[newClient] = await trx
|
||||
.insert(clients)
|
||||
.values({
|
||||
exitNodeId: randomExitNode.exitNodeId,
|
||||
orgId,
|
||||
name,
|
||||
subnet: updatedSubnet,
|
||||
type,
|
||||
olmId, // this is to lock it to a specific olm even if the olm moves across clients
|
||||
userId
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx.insert(roleClients).values({
|
||||
roleId: adminRole.roleId,
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
|
||||
trx.insert(userClients).values({
|
||||
userId,
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
|
||||
await rebuildClientAssociationsFromClient(newClient, trx);
|
||||
});
|
||||
|
||||
return response<CreateClientAndOlmResponse>(res, {
|
||||
data: newClient,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Site created successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { clients, clientSites } from "@server/db";
|
||||
import { db, olms } from "@server/db";
|
||||
import { clients, clientSitesAssociationsCache } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -9,10 +9,12 @@ import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { sendTerminateClient } from "./terminate";
|
||||
|
||||
const deleteClientSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "delete",
|
||||
@@ -58,16 +60,38 @@ export async function deleteClient(
|
||||
);
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// Delete the client-site associations first
|
||||
await trx
|
||||
.delete(clientSites)
|
||||
.where(eq(clientSites.clientId, clientId));
|
||||
if (client.userId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Cannot delete a user client with this endpoint`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// Then delete the client itself
|
||||
await trx
|
||||
const [deletedClient] = await trx
|
||||
.delete(clients)
|
||||
.where(eq(clients.clientId, clientId));
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.returning();
|
||||
|
||||
const [olm] = await trx
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
// this is a machine client so we also delete the olm
|
||||
if (!client.userId && client.olmId) {
|
||||
await trx.delete(olms).where(eq(olms.olmId, client.olmId));
|
||||
}
|
||||
|
||||
await rebuildClientAssociationsFromClient(deletedClient, trx);
|
||||
|
||||
if (olm) {
|
||||
await sendTerminateClient(deletedClient.clientId, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion
|
||||
}
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { clients, clientSites } from "@server/db";
|
||||
import { clients, clientSitesAssociationsCache } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -29,9 +29,9 @@ async function query(clientId: number) {
|
||||
|
||||
// Get the siteIds associated with this client
|
||||
const sites = await db
|
||||
.select({ siteId: clientSites.siteId })
|
||||
.from(clientSites)
|
||||
.where(eq(clientSites.clientId, clientId));
|
||||
.select({ siteId: clientSitesAssociationsCache.siteId })
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, clientId));
|
||||
|
||||
// Add the siteIds to the client object
|
||||
return {
|
||||
|
||||
@@ -3,4 +3,5 @@ export * from "./createClient";
|
||||
export * from "./deleteClient";
|
||||
export * from "./listClients";
|
||||
export * from "./updateClient";
|
||||
export * from "./getClient";
|
||||
export * from "./getClient";
|
||||
export * from "./createUserClient";
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import { db, olms } from "@server/db";
|
||||
import { db, olms, users } from "@server/db";
|
||||
import {
|
||||
clients,
|
||||
orgs,
|
||||
roleClients,
|
||||
sites,
|
||||
userClients,
|
||||
clientSites
|
||||
clientSitesAssociationsCache
|
||||
} from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import response from "@server/lib/response";
|
||||
import { and, count, eq, inArray, or, sql } from "drizzle-orm";
|
||||
import { and, count, eq, inArray, isNotNull, isNull, or, sql } from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
@@ -19,7 +19,7 @@ import { OpenAPITags, registry } from "@server/openApi";
|
||||
import NodeCache from "node-cache";
|
||||
import semver from "semver";
|
||||
|
||||
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
|
||||
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
|
||||
|
||||
async function getLatestOlmVersion(): Promise<string | null> {
|
||||
try {
|
||||
@@ -29,7 +29,7 @@ async function getLatestOlmVersion(): Promise<string | null> {
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), 1500);
|
||||
const timeoutId = setTimeout(() => controller.abort(), 1500);
|
||||
|
||||
const response = await fetch(
|
||||
"https://api.github.com/repos/fosrl/olm/tags",
|
||||
@@ -94,10 +94,25 @@ const listClientsSchema = z.object({
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
.pipe(z.int().nonnegative()),
|
||||
filter: z
|
||||
.enum(["user", "machine"])
|
||||
.optional()
|
||||
});
|
||||
|
||||
function queryClients(orgId: string, accessibleClientIds: number[]) {
|
||||
function queryClients(orgId: string, accessibleClientIds: number[], filter?: "user" | "machine") {
|
||||
const conditions = [
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
];
|
||||
|
||||
// Add filter condition based on filter type
|
||||
if (filter === "user") {
|
||||
conditions.push(isNotNull(clients.userId));
|
||||
} else if (filter === "machine") {
|
||||
conditions.push(isNull(clients.userId));
|
||||
}
|
||||
|
||||
return db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
@@ -110,17 +125,16 @@ function queryClients(orgId: string, accessibleClientIds: number[]) {
|
||||
orgName: orgs.name,
|
||||
type: clients.type,
|
||||
online: clients.online,
|
||||
olmVersion: olms.version
|
||||
olmVersion: olms.version,
|
||||
userId: clients.userId,
|
||||
username: users.username,
|
||||
userEmail: users.email
|
||||
})
|
||||
.from(clients)
|
||||
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
|
||||
.leftJoin(olms, eq(clients.clientId, olms.clientId))
|
||||
.where(
|
||||
and(
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
)
|
||||
);
|
||||
.leftJoin(users, eq(clients.userId, users.userId))
|
||||
.where(and(...conditions));
|
||||
}
|
||||
|
||||
async function getSiteAssociations(clientIds: number[]) {
|
||||
@@ -128,14 +142,14 @@ async function getSiteAssociations(clientIds: number[]) {
|
||||
|
||||
return db
|
||||
.select({
|
||||
clientId: clientSites.clientId,
|
||||
siteId: clientSites.siteId,
|
||||
clientId: clientSitesAssociationsCache.clientId,
|
||||
siteId: clientSitesAssociationsCache.siteId,
|
||||
siteName: sites.name,
|
||||
siteNiceId: sites.niceId
|
||||
})
|
||||
.from(clientSites)
|
||||
.leftJoin(sites, eq(clientSites.siteId, sites.siteId))
|
||||
.where(inArray(clientSites.clientId, clientIds));
|
||||
.from(clientSitesAssociationsCache)
|
||||
.leftJoin(sites, eq(clientSitesAssociationsCache.siteId, sites.siteId))
|
||||
.where(inArray(clientSitesAssociationsCache.clientId, clientIds));
|
||||
}
|
||||
|
||||
type OlmWithUpdateAvailable = Awaited<ReturnType<typeof queryClients>>[0] & {
|
||||
@@ -182,7 +196,7 @@ export async function listClients(
|
||||
)
|
||||
);
|
||||
}
|
||||
const { limit, offset } = parsedQuery.data;
|
||||
const { limit, offset, filter } = parsedQuery.data;
|
||||
|
||||
const parsedParams = listClientsParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -231,18 +245,24 @@ export async function listClients(
|
||||
const accessibleClientIds = accessibleClients.map(
|
||||
(client) => client.clientId
|
||||
);
|
||||
const baseQuery = queryClients(orgId, accessibleClientIds);
|
||||
const baseQuery = queryClients(orgId, accessibleClientIds, filter);
|
||||
|
||||
// Get client count with filter
|
||||
const countConditions = [
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
];
|
||||
|
||||
if (filter === "user") {
|
||||
countConditions.push(isNotNull(clients.userId));
|
||||
} else if (filter === "machine") {
|
||||
countConditions.push(isNull(clients.userId));
|
||||
}
|
||||
|
||||
// Get client count
|
||||
const countQuery = db
|
||||
.select({ count: count() })
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
)
|
||||
);
|
||||
.where(and(...countConditions));
|
||||
|
||||
const clientsList = await baseQuery.limit(limit).offset(offset);
|
||||
const totalCountResult = await countQuery;
|
||||
|
||||
@@ -1,35 +1,136 @@
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { db, olms } from "@server/db";
|
||||
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
|
||||
import logger from "@server/logger";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
export async function addTargets(
|
||||
newtId: string,
|
||||
destinationIp: string,
|
||||
destinationPort: number,
|
||||
protocol: string,
|
||||
port: number
|
||||
) {
|
||||
const target = `${port}:${destinationIp}:${destinationPort}`;
|
||||
|
||||
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/${protocol}/add`,
|
||||
data: {
|
||||
targets: [target] // We can only use one target for WireGuard right now
|
||||
}
|
||||
type: `newt/wg/targets/add`,
|
||||
data: targets
|
||||
});
|
||||
}
|
||||
|
||||
export async function removeTargets(
|
||||
newtId: string,
|
||||
destinationIp: string,
|
||||
destinationPort: number,
|
||||
protocol: string,
|
||||
port: number
|
||||
targets: SubnetProxyTarget[]
|
||||
) {
|
||||
const target = `${port}:${destinationIp}:${destinationPort}`;
|
||||
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/${protocol}/remove`,
|
||||
data: {
|
||||
targets: [target] // We can only use one target for WireGuard right now
|
||||
}
|
||||
type: `newt/wg/targets/remove`,
|
||||
data: targets
|
||||
});
|
||||
}
|
||||
|
||||
export async function updateTargets(
|
||||
newtId: string,
|
||||
targets: {
|
||||
oldTargets: SubnetProxyTarget[];
|
||||
newTargets: SubnetProxyTarget[];
|
||||
}
|
||||
) {
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/targets/update`,
|
||||
data: targets
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
|
||||
export async function addPeerData(
|
||||
clientId: number,
|
||||
siteId: number,
|
||||
remoteSubnets: string[],
|
||||
aliases: Alias[],
|
||||
olmId?: string
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return; // ignore this because an olm might not be associated with the client anymore
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/wg/peer/data/add`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
|
||||
export async function removePeerData(
|
||||
clientId: number,
|
||||
siteId: number,
|
||||
remoteSubnets: string[],
|
||||
aliases: Alias[],
|
||||
olmId?: string
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/wg/peer/data/remove`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
|
||||
export async function updatePeerData(
|
||||
clientId: number,
|
||||
siteId: number,
|
||||
remoteSubnets: {
|
||||
oldRemoteSubnets: string[];
|
||||
newRemoteSubnets: string[];
|
||||
} | undefined,
|
||||
aliases: {
|
||||
oldAliases: Alias[];
|
||||
newAliases: Alias[];
|
||||
} | undefined,
|
||||
olmId?: string
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/wg/peer/data/update`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
...remoteSubnets,
|
||||
...aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
|
||||
22
server/routers/client/terminate.ts
Normal file
22
server/routers/client/terminate.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { db, olms } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
export async function sendTerminateClient(clientId: number, olmId?: string | null) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
throw new Error(`Olm with ID ${clientId} not found`);
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/terminate`,
|
||||
data: {}
|
||||
});
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { Client, db, exitNodes, olms, sites } from "@server/db";
|
||||
import { clients, clientSites } from "@server/db";
|
||||
import { clients, clientSitesAssociationsCache } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -9,27 +9,14 @@ import logger from "@server/logger";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import {
|
||||
addPeer as newtAddPeer,
|
||||
deletePeer as newtDeletePeer
|
||||
} from "../newt/peers";
|
||||
import {
|
||||
addPeer as olmAddPeer,
|
||||
deletePeer as olmDeletePeer
|
||||
} from "../olm/peers";
|
||||
import { sendToExitNode } from "#dynamic/lib/exitNodes";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
|
||||
const updateClientParamsSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const updateClientSchema = z.strictObject({
|
||||
name: z.string().min(1).max(255).optional(),
|
||||
siteIds: z
|
||||
.array(z.int().positive())
|
||||
.optional(),
|
||||
});
|
||||
name: z.string().min(1).max(255).optional()
|
||||
});
|
||||
|
||||
export type UpdateClientBody = z.infer<typeof updateClientSchema>;
|
||||
|
||||
@@ -51,11 +38,6 @@ registry.registerPath({
|
||||
responses: {}
|
||||
});
|
||||
|
||||
interface PeerDestination {
|
||||
destinationIP: string;
|
||||
destinationPort: number;
|
||||
}
|
||||
|
||||
export async function updateClient(
|
||||
req: Request,
|
||||
res: Response,
|
||||
@@ -72,7 +54,7 @@ export async function updateClient(
|
||||
);
|
||||
}
|
||||
|
||||
const { name, siteIds } = parsedBody.data;
|
||||
const { name } = parsedBody.data;
|
||||
|
||||
const parsedParams = updateClientParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -86,7 +68,6 @@ export async function updateClient(
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
|
||||
// Fetch the client to make sure it exists and the user has access to it
|
||||
const [client] = await db
|
||||
.select()
|
||||
@@ -103,266 +84,11 @@ export async function updateClient(
|
||||
);
|
||||
}
|
||||
|
||||
let sitesAdded = [];
|
||||
let sitesRemoved = [];
|
||||
|
||||
// Fetch existing site associations
|
||||
const existingSites = await db
|
||||
.select({ siteId: clientSites.siteId })
|
||||
.from(clientSites)
|
||||
.where(eq(clientSites.clientId, clientId));
|
||||
|
||||
const existingSiteIds = existingSites.map((site) => site.siteId);
|
||||
|
||||
const siteIdsToProcess = siteIds || [];
|
||||
// Determine which sites were added and removed
|
||||
sitesAdded = siteIdsToProcess.filter(
|
||||
(siteId) => !existingSiteIds.includes(siteId)
|
||||
);
|
||||
sitesRemoved = existingSiteIds.filter(
|
||||
(siteId) => !siteIdsToProcess.includes(siteId)
|
||||
);
|
||||
|
||||
let updatedClient: Client | undefined = undefined;
|
||||
let sitesData: any; // TODO: define type somehow from the query below
|
||||
await db.transaction(async (trx) => {
|
||||
// Update client name if provided
|
||||
if (name) {
|
||||
await trx
|
||||
.update(clients)
|
||||
.set({ name })
|
||||
.where(eq(clients.clientId, clientId));
|
||||
}
|
||||
|
||||
// Update site associations if provided
|
||||
// Remove sites that are no longer associated
|
||||
for (const siteId of sitesRemoved) {
|
||||
await trx
|
||||
.delete(clientSites)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSites.clientId, clientId),
|
||||
eq(clientSites.siteId, siteId)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Add new site associations
|
||||
for (const siteId of sitesAdded) {
|
||||
await trx.insert(clientSites).values({
|
||||
clientId,
|
||||
siteId
|
||||
});
|
||||
}
|
||||
|
||||
// Fetch the updated client
|
||||
[updatedClient] = await trx
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
// get all sites for this client and join with exit nodes with site.exitNodeId
|
||||
sitesData = await trx
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
|
||||
.leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId))
|
||||
.where(eq(clientSites.clientId, client.clientId));
|
||||
});
|
||||
|
||||
logger.info(
|
||||
`Adding ${sitesAdded.length} new sites to client ${client.clientId}`
|
||||
);
|
||||
for (const siteId of sitesAdded) {
|
||||
if (!client.subnet || !client.pubKey) {
|
||||
logger.debug("Client subnet, pubKey or endpoint is not set");
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES
|
||||
// BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS
|
||||
// AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES
|
||||
const isRelayed = true;
|
||||
|
||||
const site = await newtAddPeer(siteId, {
|
||||
publicKey: client.pubKey,
|
||||
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
// endpoint: isRelayed ? "" : clientSite.endpoint
|
||||
endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint
|
||||
});
|
||||
|
||||
if (!site) {
|
||||
logger.debug("Failed to add peer to newt - missing site");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!site.endpoint || !site.publicKey) {
|
||||
logger.debug("Site endpoint or publicKey is not set");
|
||||
continue;
|
||||
}
|
||||
|
||||
let endpoint;
|
||||
|
||||
if (isRelayed) {
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no exit node, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
// get the exit node for the site
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
endpoint = `${exitNode.endpoint}:21820`;
|
||||
} else {
|
||||
if (!site.endpoint) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
endpoint = site.endpoint;
|
||||
}
|
||||
|
||||
await olmAddPeer(client.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: site.remoteSubnets
|
||||
});
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Removing ${sitesRemoved.length} sites from client ${client.clientId}`
|
||||
);
|
||||
for (const siteId of sitesRemoved) {
|
||||
if (!client.pubKey) {
|
||||
logger.debug("Client pubKey is not set");
|
||||
continue;
|
||||
}
|
||||
const site = await newtDeletePeer(siteId, client.pubKey);
|
||||
if (!site) {
|
||||
logger.debug("Failed to delete peer from newt - missing site");
|
||||
continue;
|
||||
}
|
||||
if (!site.endpoint || !site.publicKey) {
|
||||
logger.debug("Site endpoint or publicKey is not set");
|
||||
continue;
|
||||
}
|
||||
await olmDeletePeer(client.clientId, site.siteId, site.publicKey);
|
||||
}
|
||||
|
||||
if (!updatedClient || !sitesData) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
`Failed to update client`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let exitNodeDestinations: {
|
||||
reachableAt: string;
|
||||
exitNodeId: number;
|
||||
type: string;
|
||||
name: string;
|
||||
sourceIp: string;
|
||||
sourcePort: number;
|
||||
destinations: PeerDestination[];
|
||||
}[] = [];
|
||||
|
||||
for (const site of sitesData) {
|
||||
if (!site.sites.subnet) {
|
||||
logger.warn(
|
||||
`Site ${site.sites.siteId} has no subnet, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!site.clientSites.endpoint) {
|
||||
logger.warn(
|
||||
`Site ${site.sites.siteId} has no endpoint, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// find the destinations in the array
|
||||
let destinations = exitNodeDestinations.find(
|
||||
(d) => d.reachableAt === site.exitNodes?.reachableAt
|
||||
);
|
||||
|
||||
if (!destinations) {
|
||||
destinations = {
|
||||
reachableAt: site.exitNodes?.reachableAt || "",
|
||||
exitNodeId: site.exitNodes?.exitNodeId || 0,
|
||||
type: site.exitNodes?.type || "",
|
||||
name: site.exitNodes?.name || "",
|
||||
sourceIp: site.clientSites.endpoint.split(":")[0] || "",
|
||||
sourcePort:
|
||||
parseInt(site.clientSites.endpoint.split(":")[1]) || 0,
|
||||
destinations: [
|
||||
{
|
||||
destinationIP: site.sites.subnet.split("/")[0],
|
||||
destinationPort: site.sites.listenPort || 0
|
||||
}
|
||||
]
|
||||
};
|
||||
} else {
|
||||
// add to the existing destinations
|
||||
destinations.destinations.push({
|
||||
destinationIP: site.sites.subnet.split("/")[0],
|
||||
destinationPort: site.sites.listenPort || 0
|
||||
});
|
||||
}
|
||||
|
||||
// update it in the array
|
||||
exitNodeDestinations = exitNodeDestinations.filter(
|
||||
(d) => d.reachableAt !== site.exitNodes?.reachableAt
|
||||
);
|
||||
exitNodeDestinations.push(destinations);
|
||||
}
|
||||
|
||||
for (const destination of exitNodeDestinations) {
|
||||
logger.info(
|
||||
`Updating destinations for exit node at ${destination.reachableAt}`
|
||||
);
|
||||
const payload = {
|
||||
sourceIp: destination.sourceIp,
|
||||
sourcePort: destination.sourcePort,
|
||||
destinations: destination.destinations
|
||||
};
|
||||
logger.info(
|
||||
`Payload for update-destinations: ${JSON.stringify(payload, null, 2)}`
|
||||
);
|
||||
|
||||
// Create an ExitNode-like object for sendToExitNode
|
||||
const exitNodeForComm = {
|
||||
exitNodeId: destination.exitNodeId,
|
||||
type: destination.type,
|
||||
reachableAt: destination.reachableAt,
|
||||
name: destination.name
|
||||
} as any; // Using 'as any' since we know sendToExitNode will handle this correctly
|
||||
|
||||
await sendToExitNode(exitNodeForComm, {
|
||||
remoteType: "remoteExitNode/update-destinations",
|
||||
localPath: "/update-destinations",
|
||||
method: "POST",
|
||||
data: payload
|
||||
});
|
||||
}
|
||||
const updatedClient = await db
|
||||
.update(clients)
|
||||
.set({ name })
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.returning();
|
||||
|
||||
return response(res, {
|
||||
data: updatedClient,
|
||||
|
||||
@@ -16,6 +16,8 @@ import * as idp from "./idp";
|
||||
import * as blueprints from "./blueprints";
|
||||
import * as apiKeys from "./apiKeys";
|
||||
import * as logs from "./auditLogs";
|
||||
import * as newt from "./newt";
|
||||
import * as olm from "./olm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import {
|
||||
verifyAccessTokenAccess,
|
||||
@@ -27,6 +29,7 @@ import {
|
||||
verifyTargetAccess,
|
||||
verifyRoleAccess,
|
||||
verifySetResourceUsers,
|
||||
verifySetResourceClients,
|
||||
verifyUserAccess,
|
||||
getUserOrgs,
|
||||
verifyUserIsServerAdmin,
|
||||
@@ -34,14 +37,12 @@ import {
|
||||
verifyClientAccess,
|
||||
verifyApiKeyAccess,
|
||||
verifyDomainAccess,
|
||||
verifyClientsEnabled,
|
||||
verifyUserHasAction,
|
||||
verifyUserIsOrgOwner,
|
||||
verifySiteResourceAccess
|
||||
verifySiteResourceAccess,
|
||||
verifyOlmAccess
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import { createNewt, getNewtToken } from "./newt";
|
||||
import { getOlmToken } from "./olm";
|
||||
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
|
||||
import createHttpError from "http-errors";
|
||||
import { build } from "@server/build";
|
||||
@@ -129,7 +130,6 @@ authenticated.get(
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/pick-client-defaults",
|
||||
verifyClientsEnabled,
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createClient),
|
||||
client.pickClientDefaults
|
||||
@@ -137,7 +137,6 @@ authenticated.get(
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/clients",
|
||||
verifyClientsEnabled,
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.listClients),
|
||||
client.listClients
|
||||
@@ -145,7 +144,6 @@ authenticated.get(
|
||||
|
||||
authenticated.get(
|
||||
"/client/:clientId",
|
||||
verifyClientsEnabled,
|
||||
verifyClientAccess,
|
||||
verifyUserHasAction(ActionsEnum.getClient),
|
||||
client.getClient
|
||||
@@ -153,16 +151,15 @@ authenticated.get(
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/client",
|
||||
verifyClientsEnabled,
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.createClient),
|
||||
logActionAudit(ActionsEnum.createClient),
|
||||
client.createClient
|
||||
);
|
||||
|
||||
// TODO: Separate into a deleteUserClient (for user clients) and deleteClient (for machine clients)
|
||||
authenticated.delete(
|
||||
"/client/:clientId",
|
||||
verifyClientsEnabled,
|
||||
verifyClientAccess,
|
||||
verifyUserHasAction(ActionsEnum.deleteClient),
|
||||
logActionAudit(ActionsEnum.deleteClient),
|
||||
@@ -171,7 +168,6 @@ authenticated.delete(
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId",
|
||||
verifyClientsEnabled,
|
||||
verifyClientAccess, // this will check if the user has access to the client
|
||||
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
|
||||
logActionAudit(ActionsEnum.updateClient),
|
||||
@@ -286,6 +282,72 @@ authenticated.delete(
|
||||
siteResource.deleteSiteResource
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/site-resource/:siteResourceId/roles",
|
||||
verifySiteResourceAccess,
|
||||
verifyUserHasAction(ActionsEnum.listResourceRoles),
|
||||
siteResource.listSiteResourceRoles
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/site-resource/:siteResourceId/users",
|
||||
verifySiteResourceAccess,
|
||||
verifyUserHasAction(ActionsEnum.listResourceUsers),
|
||||
siteResource.listSiteResourceUsers
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/site-resource/:siteResourceId/clients",
|
||||
verifySiteResourceAccess,
|
||||
verifyUserHasAction(ActionsEnum.listResourceUsers),
|
||||
siteResource.listSiteResourceClients
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles",
|
||||
verifySiteResourceAccess,
|
||||
verifyRoleAccess,
|
||||
verifyUserHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.setSiteResourceRoles,
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/users",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceUsers,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceUsers,
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceClients,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceClients,
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/add",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceClients,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.addClientToSiteResource,
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/remove",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceClients,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.removeClientFromSiteResource,
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/resource",
|
||||
verifyOrgAccess,
|
||||
@@ -649,9 +711,15 @@ unauthenticated.get(
|
||||
// );
|
||||
|
||||
unauthenticated.get("/user", verifySessionMiddleware, user.getUser);
|
||||
unauthenticated.get("/my-device", verifySessionMiddleware, user.myDevice);
|
||||
|
||||
authenticated.get("/users", verifyUserIsServerAdmin, user.adminListUsers);
|
||||
authenticated.get("/user/:userId", verifyUserIsServerAdmin, user.adminGetUser);
|
||||
authenticated.post(
|
||||
"/user/:userId/generate-password-reset-code",
|
||||
verifyUserIsServerAdmin,
|
||||
user.adminGeneratePasswordResetCode
|
||||
);
|
||||
authenticated.delete(
|
||||
"/user/:userId",
|
||||
verifyUserIsServerAdmin,
|
||||
@@ -734,6 +802,32 @@ authenticated.delete(
|
||||
// createNewt
|
||||
// );
|
||||
|
||||
authenticated.put(
|
||||
"/user/:userId/olm",
|
||||
verifyIsLoggedInUser,
|
||||
olm.createUserOlm
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/user/:userId/olms",
|
||||
verifyIsLoggedInUser,
|
||||
olm.listUserOlms
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/user/:userId/olm/:olmId",
|
||||
verifyIsLoggedInUser,
|
||||
verifyOlmAccess,
|
||||
olm.deleteUserOlm
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/user/:userId/olm/:olmId",
|
||||
verifyIsLoggedInUser,
|
||||
verifyOlmAccess,
|
||||
olm.getUserOlm
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/idp/oidc",
|
||||
verifyUserIsServerAdmin,
|
||||
@@ -993,7 +1087,7 @@ authRouter.post(
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
getNewtToken
|
||||
newt.getNewtToken
|
||||
);
|
||||
authRouter.post(
|
||||
"/olm/get-token",
|
||||
@@ -1008,7 +1102,7 @@ authRouter.post(
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
getOlmToken
|
||||
olm.getOlmToken
|
||||
);
|
||||
|
||||
authRouter.post(
|
||||
@@ -1253,3 +1347,51 @@ authRouter.delete(
|
||||
}),
|
||||
auth.deleteSecurityKey
|
||||
);
|
||||
|
||||
authRouter.post(
|
||||
"/device-web-auth/start",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000, // 15 minutes
|
||||
max: 30, // Allow 30 device auth code requests per 15 minutes per IP
|
||||
keyGenerator: (req) =>
|
||||
`deviceWebAuthStart:${ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only request a device auth code ${30} times every ${15} minutes. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.startDeviceWebAuth
|
||||
);
|
||||
|
||||
authRouter.get(
|
||||
"/device-web-auth/poll/:code",
|
||||
rateLimit({
|
||||
windowMs: 60 * 1000, // 1 minute
|
||||
max: 60, // Allow 60 polling requests per minute per IP (poll every second)
|
||||
keyGenerator: (req) =>
|
||||
`deviceWebAuthPoll:${ipKeyGenerator(req.ip || "")}:${req.params.code}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only poll a device auth code ${60} times per minute. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.pollDeviceWebAuth
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/device-web-auth/verify",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000, // 15 minutes
|
||||
max: 50, // Allow 50 verification attempts per 15 minutes per user
|
||||
keyGenerator: (req) =>
|
||||
`deviceWebAuthVerify:${req.user?.userId || ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only verify a device auth code ${50} times every ${15} minutes. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.verifyDeviceWebAuth
|
||||
);
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
olms,
|
||||
Site,
|
||||
sites,
|
||||
clientSites,
|
||||
clientSitesAssociationsCache,
|
||||
ExitNode
|
||||
} from "@server/db";
|
||||
import { db } from "@server/db";
|
||||
@@ -109,8 +109,8 @@ export async function generateRelayMappings(exitNode: ExitNode) {
|
||||
// Find all clients associated with this site through clientSites
|
||||
const clientSitesRes = await db
|
||||
.select()
|
||||
.from(clientSites)
|
||||
.where(eq(clientSites.siteId, site.siteId));
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(eq(clientSitesAssociationsCache.siteId, site.siteId));
|
||||
|
||||
for (const clientSite of clientSitesRes) {
|
||||
if (!clientSite.endpoint) {
|
||||
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
olms,
|
||||
Site,
|
||||
sites,
|
||||
clientSites,
|
||||
clientSitesAssociationsCache,
|
||||
exitNodes,
|
||||
ExitNode
|
||||
} from "@server/db";
|
||||
@@ -19,6 +19,8 @@ import { fromError } from "zod-validation-error";
|
||||
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
|
||||
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
|
||||
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
|
||||
import { updatePeer as updateOlmPeer } from "../olm/peers";
|
||||
import { updatePeer as updateNewtPeer } from "../newt/peers";
|
||||
|
||||
// Define Zod schema for request validation
|
||||
const updateHolePunchSchema = z.object({
|
||||
@@ -28,8 +30,9 @@ const updateHolePunchSchema = z.object({
|
||||
ip: z.string(),
|
||||
port: z.number(),
|
||||
timestamp: z.number(),
|
||||
publicKey: z.string(),
|
||||
reachableAt: z.string().optional(),
|
||||
publicKey: z.string().optional()
|
||||
exitNodePublicKey: z.string().optional()
|
||||
});
|
||||
|
||||
// New response type with multi-peer destination support
|
||||
@@ -63,23 +66,26 @@ export async function updateHolePunch(
|
||||
timestamp,
|
||||
token,
|
||||
reachableAt,
|
||||
publicKey
|
||||
publicKey, // this is the client's current public key for this session
|
||||
exitNodePublicKey
|
||||
} = parsedParams.data;
|
||||
|
||||
let exitNode: ExitNode | undefined;
|
||||
if (publicKey) {
|
||||
if (exitNodePublicKey) {
|
||||
// Get the exit node by public key
|
||||
[exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.publicKey, publicKey));
|
||||
.where(eq(exitNodes.publicKey, exitNodePublicKey));
|
||||
} else {
|
||||
// FOR BACKWARDS COMPATIBILITY IF GERBIL IS STILL =<1.1.0
|
||||
[exitNode] = await db.select().from(exitNodes).limit(1);
|
||||
}
|
||||
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for publicKey: ${publicKey}`);
|
||||
logger.warn(
|
||||
`Exit node not found for publicKey: ${exitNodePublicKey}`
|
||||
);
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Exit node not found")
|
||||
);
|
||||
@@ -92,12 +98,13 @@ export async function updateHolePunch(
|
||||
port,
|
||||
timestamp,
|
||||
token,
|
||||
publicKey,
|
||||
exitNode
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Returning ${destinations.length} peer destinations for olmId: ${olmId} or newtId: ${newtId}: ${JSON.stringify(destinations, null, 2)}`
|
||||
// );
|
||||
|
||||
// Return the new multi-peer structure
|
||||
return res.status(HttpCode.OK).send({
|
||||
@@ -121,6 +128,7 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
port: number,
|
||||
timestamp: number,
|
||||
token: string,
|
||||
publicKey: string,
|
||||
exitNode: ExitNode,
|
||||
checkOrg = false
|
||||
) {
|
||||
@@ -128,9 +136,9 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
const destinations: PeerDestination[] = [];
|
||||
|
||||
if (olmId) {
|
||||
logger.debug(
|
||||
`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`
|
||||
// );
|
||||
|
||||
const { session, olm: olmSession } =
|
||||
await validateOlmSessionToken(token);
|
||||
@@ -150,7 +158,7 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
throw new Error("Olm not found");
|
||||
}
|
||||
|
||||
const [client] = await db
|
||||
const [updatedClient] = await db
|
||||
.update(clients)
|
||||
.set({
|
||||
lastHolePunch: timestamp
|
||||
@@ -158,10 +166,16 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
.where(eq(clients.clientId, olm.clientId))
|
||||
.returning();
|
||||
|
||||
if (await checkExitNodeOrg(exitNode.exitNodeId, client.orgId) && checkOrg) {
|
||||
if (
|
||||
(await checkExitNodeOrg(
|
||||
exitNode.exitNodeId,
|
||||
updatedClient.orgId
|
||||
)) &&
|
||||
checkOrg
|
||||
) {
|
||||
// not allowed
|
||||
logger.warn(
|
||||
`Exit node ${exitNode.exitNodeId} is not allowed for org ${client.orgId}`
|
||||
`Exit node ${exitNode.exitNodeId} is not allowed for org ${updatedClient.orgId}`
|
||||
);
|
||||
throw new Error("Exit node not allowed");
|
||||
}
|
||||
@@ -171,40 +185,70 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
.select({
|
||||
siteId: sites.siteId,
|
||||
subnet: sites.subnet,
|
||||
listenPort: sites.listenPort
|
||||
listenPort: sites.listenPort,
|
||||
publicKey: sites.publicKey,
|
||||
endpoint: clientSitesAssociationsCache.endpoint
|
||||
})
|
||||
.from(sites)
|
||||
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.exitNodeId, exitNode.exitNodeId),
|
||||
eq(clientSites.clientId, olm.clientId)
|
||||
eq(clientSitesAssociationsCache.clientId, olm.clientId)
|
||||
)
|
||||
);
|
||||
|
||||
// Update clientSites for each site on this exit node
|
||||
for (const site of sitesOnExitNode) {
|
||||
logger.debug(
|
||||
`Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
|
||||
// );
|
||||
|
||||
await db
|
||||
.update(clientSites)
|
||||
// if the public key or endpoint has changed, update it otherwise continue
|
||||
if (
|
||||
site.endpoint === `${ip}:${port}` &&
|
||||
site.publicKey === publicKey
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const [updatedClientSitesAssociationsCache] = await db
|
||||
.update(clientSitesAssociationsCache)
|
||||
.set({
|
||||
endpoint: `${ip}:${port}`
|
||||
endpoint: `${ip}:${port}`,
|
||||
publicKey: publicKey
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(clientSites.clientId, olm.clientId),
|
||||
eq(clientSites.siteId, site.siteId)
|
||||
eq(clientSitesAssociationsCache.clientId, olm.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
)
|
||||
.returning();
|
||||
|
||||
if (
|
||||
updatedClientSitesAssociationsCache.endpoint !==
|
||||
site.endpoint && // this is the endpoint from the join table not the site
|
||||
updatedClient.pubKey === publicKey // only trigger if the client's public key matches the current public key which means it has registered so we dont prematurely send the update
|
||||
) {
|
||||
logger.info(
|
||||
`ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}`
|
||||
);
|
||||
// Handle any additional logic for endpoint change
|
||||
handleClientEndpointChange(
|
||||
olm.clientId,
|
||||
updatedClientSitesAssociationsCache.endpoint!
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}`
|
||||
);
|
||||
if (!client) {
|
||||
// logger.debug(
|
||||
// `Updated ${sitesOnExitNode.length} sites on exit node ${exitNode.exitNodeId}`
|
||||
// );
|
||||
if (!updatedClient) {
|
||||
logger.warn(`Client not found for olm: ${olmId}`);
|
||||
throw new Error("Client not found");
|
||||
}
|
||||
@@ -219,9 +263,9 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
}
|
||||
}
|
||||
} else if (newtId) {
|
||||
logger.debug(
|
||||
`Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Got hole punch with ip: ${ip}, port: ${port} for newtId: ${newtId}`
|
||||
// );
|
||||
|
||||
const { session, newt: newtSession } =
|
||||
await validateNewtSessionToken(token);
|
||||
@@ -253,7 +297,10 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
.where(eq(sites.siteId, newt.siteId))
|
||||
.limit(1);
|
||||
|
||||
if (await checkExitNodeOrg(exitNode.exitNodeId, site.orgId) && checkOrg) {
|
||||
if (
|
||||
(await checkExitNodeOrg(exitNode.exitNodeId, site.orgId)) &&
|
||||
checkOrg
|
||||
) {
|
||||
// not allowed
|
||||
logger.warn(
|
||||
`Exit node ${exitNode.exitNodeId} is not allowed for org ${site.orgId}`
|
||||
@@ -273,6 +320,18 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
.where(eq(sites.siteId, newt.siteId))
|
||||
.returning();
|
||||
|
||||
if (
|
||||
updatedSite.endpoint != site.endpoint &&
|
||||
updatedSite.publicKey == publicKey
|
||||
) {
|
||||
// only trigger if the site's public key matches the current public key which means it has registered so we dont prematurely send the update
|
||||
logger.info(
|
||||
`Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}`
|
||||
);
|
||||
// Handle any additional logic for endpoint change
|
||||
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!);
|
||||
}
|
||||
|
||||
if (!updatedSite || !updatedSite.subnet) {
|
||||
logger.warn(`Site not found: ${newt.siteId}`);
|
||||
throw new Error("Site not found");
|
||||
@@ -326,3 +385,143 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
}
|
||||
return destinations;
|
||||
}
|
||||
|
||||
async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
|
||||
// Alert all clients connected to this site that the endpoint has changed (only if NOT relayed)
|
||||
try {
|
||||
// Get site details
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site || !site.publicKey) {
|
||||
logger.warn(`Site ${siteId} not found or has no public key`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get all non-relayed clients connected to this site
|
||||
const connectedClients = await db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
olmId: olms.olmId,
|
||||
isRelayed: clientSitesAssociationsCache.isRelayed
|
||||
})
|
||||
.from(clientSitesAssociationsCache)
|
||||
.innerJoin(
|
||||
clients,
|
||||
eq(clientSitesAssociationsCache.clientId, clients.clientId)
|
||||
)
|
||||
.innerJoin(olms, eq(olms.clientId, clients.clientId))
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.siteId, siteId),
|
||||
eq(clientSitesAssociationsCache.isRelayed, false)
|
||||
)
|
||||
);
|
||||
|
||||
// Update each non-relayed client with the new site endpoint
|
||||
for (const client of connectedClients) {
|
||||
try {
|
||||
await updateOlmPeer(
|
||||
client.clientId,
|
||||
{
|
||||
siteId: siteId,
|
||||
publicKey: site.publicKey,
|
||||
endpoint: newEndpoint
|
||||
},
|
||||
client.olmId
|
||||
);
|
||||
logger.debug(
|
||||
`Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}`
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to update client ${client.clientId} with new site endpoint: ${error}`
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling site endpoint change for site ${siteId}: ${error}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleClientEndpointChange(
|
||||
clientId: number,
|
||||
newEndpoint: string
|
||||
) {
|
||||
// Alert all sites connected to this client that the endpoint has changed (only if NOT relayed)
|
||||
try {
|
||||
// Get client details
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client || !client.pubKey) {
|
||||
logger.warn(`Client ${clientId} not found or has no public key`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get all non-relayed sites connected to this client
|
||||
const connectedSites = await db
|
||||
.select({
|
||||
siteId: sites.siteId,
|
||||
newtId: newts.newtId,
|
||||
isRelayed: clientSitesAssociationsCache.isRelayed,
|
||||
subnet: clients.subnet
|
||||
})
|
||||
.from(clientSitesAssociationsCache)
|
||||
.innerJoin(
|
||||
sites,
|
||||
eq(clientSitesAssociationsCache.siteId, sites.siteId)
|
||||
)
|
||||
.innerJoin(newts, eq(newts.siteId, sites.siteId))
|
||||
.innerJoin(
|
||||
clients,
|
||||
eq(clientSitesAssociationsCache.clientId, clients.clientId)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, clientId),
|
||||
eq(clientSitesAssociationsCache.isRelayed, false)
|
||||
)
|
||||
);
|
||||
|
||||
// Update each non-relayed site with the new client endpoint
|
||||
for (const siteData of connectedSites) {
|
||||
try {
|
||||
if (!siteData.subnet) {
|
||||
logger.warn(
|
||||
`Client ${clientId} has no subnet, skipping update for site ${siteData.siteId}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
await updateNewtPeer(
|
||||
siteData.siteId,
|
||||
client.pubKey,
|
||||
{
|
||||
endpoint: newEndpoint
|
||||
},
|
||||
siteData.newtId
|
||||
);
|
||||
logger.debug(
|
||||
`Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}`
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to update site ${siteData.siteId} with new client endpoint: ${error}`
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling client endpoint change for client ${clientId}: ${error}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ import { UserType } from "@server/types/UserTypes";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { build } from "@server/build";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
|
||||
const ensureTrailingSlash = (url: string): string => {
|
||||
return url;
|
||||
@@ -364,10 +365,18 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (!existingUserOrgs.length) {
|
||||
// delete the user
|
||||
// await db
|
||||
// .delete(users)
|
||||
// .where(eq(users.userId, existingUser.userId));
|
||||
// delete all auto -provisioned user orgs
|
||||
await db
|
||||
.delete(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, existingUser.userId),
|
||||
eq(userOrgs.autoProvisioned, true)
|
||||
)
|
||||
);
|
||||
|
||||
await calculateUserClientsForOrgs(existingUser.userId);
|
||||
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
@@ -513,6 +522,8 @@ export async function validateOidcCallback(
|
||||
userCount: userCount.length
|
||||
});
|
||||
}
|
||||
|
||||
await calculateUserClientsForOrgs(userId!, trx);
|
||||
});
|
||||
|
||||
for (const orgCount of orgUserCounts) {
|
||||
@@ -553,6 +564,24 @@ export async function validateOidcCallback(
|
||||
);
|
||||
}
|
||||
|
||||
// check for existing user orgs
|
||||
const existingUserOrgs = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(and(eq(userOrgs.userId, existingUser.userId)));
|
||||
|
||||
if (!existingUserOrgs.length) {
|
||||
logger.debug(
|
||||
"No existing user orgs found for non-auto-provisioned IdP"
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
`User with username ${userIdentifier} is unprovisioned. This user must be added to an organization before logging in.`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const token = generateSessionToken();
|
||||
const sess = await createSession(token, existingUser.userId);
|
||||
const isSecure = req.protocol === "https";
|
||||
|
||||
@@ -10,6 +10,7 @@ import * as client from "./client";
|
||||
import * as accessToken from "./accessToken";
|
||||
import * as apiKeys from "./apiKeys";
|
||||
import * as idp from "./idp";
|
||||
import * as logs from "./auditLogs";
|
||||
import * as siteResource from "./siteResource";
|
||||
import {
|
||||
verifyApiKey,
|
||||
@@ -24,8 +25,8 @@ import {
|
||||
verifyApiKeyAccessTokenAccess,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyClientAccess,
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeySiteResourceAccess
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients
|
||||
} from "@server/middlewares";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { Router } from "express";
|
||||
@@ -197,6 +198,108 @@ authenticated.delete(
|
||||
siteResource.deleteSiteResource
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/site-resource/:siteResourceId/roles",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listResourceRoles),
|
||||
siteResource.listSiteResourceRoles
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/site-resource/:siteResourceId/users",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listResourceUsers),
|
||||
siteResource.listSiteResourceUsers
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/site-resource/:siteResourceId/clients",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listResourceUsers),
|
||||
siteResource.listSiteResourceClients
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.setSiteResourceRoles
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/users",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceUsers
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles/add",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.addRoleToSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles/remove",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.removeRoleFromSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/users/add",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.addUserToSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/users/remove",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.removeUserFromSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceClients
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/add",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.addClientToSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/remove",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.removeClientFromSiteResource
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/resource",
|
||||
verifyApiKeyOrgAccess,
|
||||
@@ -412,6 +515,42 @@ authenticated.post(
|
||||
resource.setResourceUsers
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/roles/add",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
resource.addRoleToResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/roles/remove",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
resource.removeRoleFromResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/users/add",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
resource.addUserToResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/users/remove",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
resource.removeUserFromResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/password`,
|
||||
verifyApiKeyResourceAccess,
|
||||
@@ -657,7 +796,6 @@ authenticated.get(
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/pick-client-defaults",
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.createClient),
|
||||
client.pickClientDefaults
|
||||
@@ -665,7 +803,6 @@ authenticated.get(
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/clients",
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listClients),
|
||||
client.listClients
|
||||
@@ -673,7 +810,6 @@ authenticated.get(
|
||||
|
||||
authenticated.get(
|
||||
"/client/:clientId",
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeyClientAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getClient),
|
||||
client.getClient
|
||||
@@ -681,16 +817,24 @@ authenticated.get(
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/client",
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.createClient),
|
||||
logActionAudit(ActionsEnum.createClient),
|
||||
client.createClient
|
||||
);
|
||||
|
||||
// authenticated.put(
|
||||
// "/org/:orgId/user/:userId/client",
|
||||
// verifyClientsEnabled,
|
||||
// verifyApiKeyOrgAccess,
|
||||
// verifyApiKeyUserAccess,
|
||||
// verifyApiKeyHasAction(ActionsEnum.createClient),
|
||||
// logActionAudit(ActionsEnum.createClient),
|
||||
// client.createUserClient
|
||||
// );
|
||||
|
||||
authenticated.delete(
|
||||
"/client/:clientId",
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeyClientAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteClient),
|
||||
logActionAudit(ActionsEnum.deleteClient),
|
||||
@@ -699,7 +843,6 @@ authenticated.delete(
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId",
|
||||
verifyClientsEnabled,
|
||||
verifyApiKeyClientAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateClient),
|
||||
logActionAudit(ActionsEnum.updateClient),
|
||||
@@ -713,3 +856,32 @@ authenticated.put(
|
||||
logActionAudit(ActionsEnum.applyBlueprint),
|
||||
blueprints.applyJSONBlueprint
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/request",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.viewLogs),
|
||||
logs.queryRequestAuditLogs
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/request/export",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
logs.exportRequestAuditLogs
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/analytics",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.viewLogs),
|
||||
logs.queryRequestAnalytics
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/resource-names",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listResources),
|
||||
resource.listAllResourceNames
|
||||
);
|
||||
|
||||
@@ -6,15 +6,15 @@ import {
|
||||
db,
|
||||
ExitNode,
|
||||
exitNodes,
|
||||
resources,
|
||||
siteResources,
|
||||
Target,
|
||||
targets
|
||||
clientSiteResourcesAssociationsCache
|
||||
} from "@server/db";
|
||||
import { clients, clientSites, Newt, sites } from "@server/db";
|
||||
import { eq, and, inArray } from "drizzle-orm";
|
||||
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { updatePeer } from "../olm/peers";
|
||||
import { sendToExitNode } from "#dynamic/lib/exitNodes";
|
||||
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const inputSchema = z.object({
|
||||
publicKey: z.string(),
|
||||
@@ -66,7 +66,9 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
|
||||
// we need to wait for hole punch success
|
||||
if (!existingSite.endpoint) {
|
||||
logger.debug(`In newt get config: existing site ${existingSite.siteId} has no endpoint, skipping`);
|
||||
logger.debug(
|
||||
`In newt get config: existing site ${existingSite.siteId} has no endpoint, skipping`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -74,12 +76,12 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
// TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly)
|
||||
}
|
||||
|
||||
// if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) {
|
||||
// logger.warn(
|
||||
// `Site ${existingSite.siteId} last hole punch is too old, skipping`
|
||||
// );
|
||||
// return;
|
||||
// }
|
||||
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
|
||||
logger.warn(
|
||||
`handleGetConfigMessage: Site ${existingSite.siteId} last hole punch is too old, skipping`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// update the endpoint and the public key
|
||||
const [site] = await db
|
||||
@@ -132,75 +134,95 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
const clientsRes = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
|
||||
.where(eq(clientSites.siteId, siteId));
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(clients.clientId, clientSitesAssociationsCache.clientId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.siteId, siteId));
|
||||
|
||||
// Prepare peers data for the response
|
||||
const peers = await Promise.all(
|
||||
clientsRes
|
||||
.filter((client) => {
|
||||
if (!client.clients.pubKey) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no public key, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
if (!client.clients.subnet) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no subnet, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map(async (client) => {
|
||||
// Add or update this peer on the olm if it is connected
|
||||
try {
|
||||
if (!site.publicKey) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no public key, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
let endpoint = site.endpoint;
|
||||
if (client.clientSites.isRelayed) {
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no exit node, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!exitNode) {
|
||||
logger.warn(
|
||||
`Exit node not found for site ${site.siteId}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
endpoint = `${exitNode.endpoint}:21820`;
|
||||
}
|
||||
|
||||
if (!endpoint) {
|
||||
logger.warn(
|
||||
`In Newt get config: Peer site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
await updatePeer(client.clients.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: site.remoteSubnets
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to add/update peer ${client.clients.pubKey} to olm ${newt.newtId}: ${error}`
|
||||
if (!site.publicKey) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no public key, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!site.endpoint) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
// const allSiteResources = await db // only get the site resources that this client has access to
|
||||
// .select()
|
||||
// .from(siteResources)
|
||||
// .innerJoin(
|
||||
// clientSiteResourcesAssociationsCache,
|
||||
// eq(
|
||||
// siteResources.siteResourceId,
|
||||
// clientSiteResourcesAssociationsCache.siteResourceId
|
||||
// )
|
||||
// )
|
||||
// .where(
|
||||
// and(
|
||||
// eq(siteResources.siteId, site.siteId),
|
||||
// eq(
|
||||
// clientSiteResourcesAssociationsCache.clientId,
|
||||
// client.clients.clientId
|
||||
// )
|
||||
// )
|
||||
// );
|
||||
await updatePeer(client.clients.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: site.endpoint,
|
||||
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort
|
||||
// remoteSubnets: generateRemoteSubnets(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// ),
|
||||
// aliases: generateAliasConfig(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// )
|
||||
});
|
||||
|
||||
return {
|
||||
publicKey: client.clients.pubKey!,
|
||||
allowedIps: [`${client.clients.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
endpoint: client.clientSites.isRelayed
|
||||
endpoint: client.clientSitesAssociationsCache.isRelayed
|
||||
? ""
|
||||
: client.clientSites.endpoint! // if its relayed it should be localhost
|
||||
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
|
||||
};
|
||||
})
|
||||
);
|
||||
@@ -208,42 +230,50 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
// Filter out any null values from peers that didn't have an olm
|
||||
const validPeers = peers.filter((peer) => peer !== null);
|
||||
|
||||
// Get all enabled targets with their resource protocol information
|
||||
// Get all enabled site resources for this site
|
||||
const allSiteResources = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(eq(siteResources.siteId, siteId));
|
||||
|
||||
const { tcpTargets, udpTargets } = allSiteResources.reduce(
|
||||
(acc, resource) => {
|
||||
// Filter out invalid targets
|
||||
if (!resource.proxyPort || !resource.destinationIp || !resource.destinationPort) {
|
||||
return acc;
|
||||
}
|
||||
const targetsToSend: SubnetProxyTarget[] = [];
|
||||
|
||||
// Format target into string
|
||||
const formattedTarget = `${resource.proxyPort}:${resource.destinationIp}:${resource.destinationPort}`;
|
||||
for (const resource of allSiteResources) {
|
||||
// Get clients associated with this specific resource
|
||||
const resourceClients = await db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
pubKey: clients.pubKey,
|
||||
subnet: clients.subnet
|
||||
})
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
clients.clientId,
|
||||
clientSiteResourcesAssociationsCache.clientId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||
resource.siteResourceId
|
||||
)
|
||||
);
|
||||
|
||||
// Add to the appropriate protocol array
|
||||
if (resource.protocol === "tcp") {
|
||||
acc.tcpTargets.push(formattedTarget);
|
||||
} else {
|
||||
acc.udpTargets.push(formattedTarget);
|
||||
}
|
||||
const resourceTargets = generateSubnetProxyTargets(
|
||||
resource,
|
||||
resourceClients
|
||||
);
|
||||
|
||||
return acc;
|
||||
},
|
||||
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
|
||||
);
|
||||
targetsToSend.push(...resourceTargets);
|
||||
}
|
||||
|
||||
// Build the configuration response
|
||||
const configResponse = {
|
||||
ipAddress: site.address,
|
||||
peers: validPeers,
|
||||
targets: {
|
||||
udp: udpTargets,
|
||||
tcp: tcpTargets
|
||||
}
|
||||
targets: targetsToSend
|
||||
};
|
||||
|
||||
logger.debug("Sending config: ", configResponse);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { db, exitNodeOrgs, newts } from "@server/db";
|
||||
import { db, ExitNode, exitNodeOrgs, newts, Transaction } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { exitNodes, Newt, resources, sites, Target, targets } from "@server/db";
|
||||
import { targetHealthCheck } from "@server/db";
|
||||
import { eq, and, sql, inArray } from "drizzle-orm";
|
||||
import { eq, and, sql, inArray, ne } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../gerbil/peers";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
verifyExitNodeOrgAccess
|
||||
} from "#dynamic/lib/exitNodes";
|
||||
import { fetchContainers } from "./dockerSocket";
|
||||
import { lockManager } from "#dynamic/lib/lock";
|
||||
|
||||
export type ExitNodePingResult = {
|
||||
exitNodeId: number;
|
||||
@@ -151,27 +152,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const sitesQuery = await db
|
||||
.select({
|
||||
subnet: sites.subnet
|
||||
})
|
||||
.from(sites)
|
||||
.where(eq(sites.exitNodeId, exitNodeId));
|
||||
const newSubnet = await getUniqueSubnetForSite(exitNode);
|
||||
|
||||
const blockSize = config.getRawConfig().gerbil.site_block_size;
|
||||
const subnets = sitesQuery
|
||||
.map((site) => site.subnet)
|
||||
.filter(
|
||||
(subnet) =>
|
||||
subnet && /^(\d{1,3}\.){3}\d{1,3}\/\d{1,2}$/.test(subnet)
|
||||
)
|
||||
.filter((subnet) => subnet !== null);
|
||||
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
|
||||
const newSubnet = findNextAvailableCidr(
|
||||
subnets,
|
||||
blockSize,
|
||||
exitNode.address
|
||||
);
|
||||
if (!newSubnet) {
|
||||
logger.error(
|
||||
`No available subnets found for the new exit node id ${exitNodeId} and site id ${siteId}`
|
||||
@@ -378,3 +360,39 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
excludeSender: false // Include sender in broadcast
|
||||
};
|
||||
};
|
||||
|
||||
async function getUniqueSubnetForSite(
|
||||
exitNode: ExitNode,
|
||||
trx: Transaction | typeof db = db
|
||||
): Promise<string | null> {
|
||||
const lockKey = `subnet-allocation:${exitNode.exitNodeId}`;
|
||||
|
||||
return await lockManager.withLock(
|
||||
lockKey,
|
||||
async () => {
|
||||
const sitesQuery = await trx
|
||||
.select({
|
||||
subnet: sites.subnet
|
||||
})
|
||||
.from(sites)
|
||||
.where(eq(sites.exitNodeId, exitNode.exitNodeId));
|
||||
|
||||
const blockSize = config.getRawConfig().gerbil.site_block_size;
|
||||
const subnets = sitesQuery
|
||||
.map((site) => site.subnet)
|
||||
.filter(
|
||||
(subnet) =>
|
||||
subnet && /^(\d{1,3}\.){3}\d{1,3}\/\d{1,2}$/.test(subnet)
|
||||
)
|
||||
.filter((subnet) => subnet !== null);
|
||||
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
|
||||
const newSubnet = findNextAvailableCidr(
|
||||
subnets,
|
||||
blockSize,
|
||||
exitNode.address
|
||||
);
|
||||
return newSubnet;
|
||||
},
|
||||
5000 // 5 second lock TTL - subnet allocation should be quick
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { db } from "@server/db";
|
||||
import { db, Site } from "@server/db";
|
||||
import { newts, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
@@ -10,65 +10,78 @@ export async function addPeer(
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint: string;
|
||||
}
|
||||
},
|
||||
newtId?: string
|
||||
) {
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!site) {
|
||||
throw new Error(`Exit node with ID ${siteId} not found`);
|
||||
let site: Site | null = null;
|
||||
if (!newtId) {
|
||||
[site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!site) {
|
||||
throw new Error(`Site with ID ${siteId} not found`);
|
||||
}
|
||||
|
||||
// get the newt on the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!newt) {
|
||||
throw new Error(`Site found for site ${siteId}`);
|
||||
}
|
||||
newtId = newt.newtId;
|
||||
}
|
||||
|
||||
// get the newt on the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!newt) {
|
||||
throw new Error(`Site found for site ${siteId}`);
|
||||
}
|
||||
|
||||
sendToClient(newt.newtId, {
|
||||
await sendToClient(newtId, {
|
||||
type: "newt/wg/peer/add",
|
||||
data: peer
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`);
|
||||
logger.info(`Added peer ${peer.publicKey} to newt ${newtId}`);
|
||||
|
||||
return site;
|
||||
}
|
||||
|
||||
export async function deletePeer(siteId: number, publicKey: string) {
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!site) {
|
||||
throw new Error(`Site with ID ${siteId} not found`);
|
||||
export async function deletePeer(siteId: number, publicKey: string, newtId?: string) {
|
||||
let site: Site | null = null;
|
||||
if (!newtId) {
|
||||
[site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!site) {
|
||||
throw new Error(`Site with ID ${siteId} not found`);
|
||||
}
|
||||
|
||||
// get the newt on the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!newt) {
|
||||
throw new Error(`Newt not found for site ${siteId}`);
|
||||
}
|
||||
newtId = newt.newtId;
|
||||
}
|
||||
|
||||
// get the newt on the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!newt) {
|
||||
throw new Error(`Newt not found for site ${siteId}`);
|
||||
}
|
||||
|
||||
sendToClient(newt.newtId, {
|
||||
await sendToClient(newtId, {
|
||||
type: "newt/wg/peer/remove",
|
||||
data: {
|
||||
publicKey
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`);
|
||||
logger.info(`Deleted peer ${publicKey} from newt ${newtId}`);
|
||||
|
||||
return site;
|
||||
}
|
||||
@@ -79,36 +92,43 @@ export async function updatePeer(
|
||||
peer: {
|
||||
allowedIps?: string[];
|
||||
endpoint?: string;
|
||||
}
|
||||
},
|
||||
newtId?: string
|
||||
) {
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!site) {
|
||||
throw new Error(`Site with ID ${siteId} not found`);
|
||||
let site: Site | null = null;
|
||||
if (!newtId) {
|
||||
[site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!site) {
|
||||
throw new Error(`Site with ID ${siteId} not found`);
|
||||
}
|
||||
|
||||
// get the newt on the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!newt) {
|
||||
throw new Error(`Newt not found for site ${siteId}`);
|
||||
}
|
||||
newtId = newt.newtId;
|
||||
}
|
||||
|
||||
// get the newt on the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.limit(1);
|
||||
if (!newt) {
|
||||
throw new Error(`Newt not found for site ${siteId}`);
|
||||
}
|
||||
|
||||
sendToClient(newt.newtId, {
|
||||
await sendToClient(newtId, {
|
||||
type: "newt/wg/peer/update",
|
||||
data: {
|
||||
publicKey,
|
||||
...peer
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`);
|
||||
logger.info(`Updated peer ${publicKey} on newt ${newtId}`);
|
||||
|
||||
return site;
|
||||
}
|
||||
|
||||
116
server/routers/olm/createUserOlm.ts
Normal file
116
server/routers/olm/createUserOlm.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, olms } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import moment from "moment";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
name: z.string().min(1).max(255)
|
||||
})
|
||||
.strict();
|
||||
|
||||
const paramsSchema = z.object({
|
||||
userId: z.string()
|
||||
});
|
||||
|
||||
export type CreateOlmBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type CreateOlmResponse = {
|
||||
olmId: string;
|
||||
secret: string;
|
||||
};
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "put",
|
||||
// path: "/user/{userId}/olm",
|
||||
// description: "Create a new olm for a user.",
|
||||
// tags: [OpenAPITags.User, OpenAPITags.Client],
|
||||
// request: {
|
||||
// body: {
|
||||
// content: {
|
||||
// "application/json": {
|
||||
// schema: bodySchema
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// params: paramsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function createUserOlm(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { name } = parsedBody.data;
|
||||
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { userId } = parsedParams.data;
|
||||
|
||||
const olmId = generateId(15);
|
||||
const secret = generateId(48);
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx.insert(olms).values({
|
||||
olmId: olmId,
|
||||
userId,
|
||||
name,
|
||||
secretHash,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
await calculateUserClientsForOrgs(userId, trx);
|
||||
});
|
||||
|
||||
return response<CreateOlmResponse>(res, {
|
||||
data: {
|
||||
olmId,
|
||||
secret
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Olm created successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create olm"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
101
server/routers/olm/deleteUserOlm.ts
Normal file
101
server/routers/olm/deleteUserOlm.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { Client, db } from "@server/db";
|
||||
import { olms, clients, clientSitesAssociationsCache } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
userId: z.string(),
|
||||
olmId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "delete",
|
||||
// path: "/user/{userId}/olm/{olmId}",
|
||||
// description: "Delete an olm for a user.",
|
||||
// tags: [OpenAPITags.User, OpenAPITags.Client],
|
||||
// request: {
|
||||
// params: paramsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function deleteUserOlm(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId } = parsedParams.data;
|
||||
|
||||
// Delete associated clients and the OLM in a transaction
|
||||
await db.transaction(async (trx) => {
|
||||
// Find all clients associated with this OLM
|
||||
const associatedClients = await trx
|
||||
.select({ clientId: clients.clientId })
|
||||
.from(clients)
|
||||
.where(eq(clients.olmId, olmId));
|
||||
|
||||
let deletedClient: Client | null = null;
|
||||
// Delete all associated clients
|
||||
if (associatedClients.length > 0) {
|
||||
[deletedClient] = await trx
|
||||
.delete(clients)
|
||||
.where(eq(clients.olmId, olmId))
|
||||
.returning();
|
||||
}
|
||||
|
||||
// Finally, delete the OLM itself
|
||||
const [olm] = await trx
|
||||
.delete(olms)
|
||||
.where(eq(olms.olmId, olmId))
|
||||
.returning();
|
||||
|
||||
if (deletedClient) {
|
||||
await rebuildClientAssociationsFromClient(deletedClient, trx);
|
||||
if (olm) {
|
||||
await sendTerminateClient(
|
||||
deletedClient.clientId,
|
||||
olm.olmId
|
||||
); // the olmId needs to be provided because it cant look it up after deletion
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Device deleted successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to delete device"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,16 @@
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
clients,
|
||||
db,
|
||||
ExitNode,
|
||||
exitNodes,
|
||||
sites,
|
||||
clientSitesAssociationsCache
|
||||
} from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import response from "@server/lib/response";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { and, eq, inArray } from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
@@ -19,7 +26,8 @@ import config from "@server/lib/config";
|
||||
export const olmGetTokenBodySchema = z.object({
|
||||
olmId: z.string(),
|
||||
secret: z.string(),
|
||||
token: z.string().optional()
|
||||
token: z.string().optional(),
|
||||
orgId: z.string().optional()
|
||||
});
|
||||
|
||||
export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>;
|
||||
@@ -40,7 +48,7 @@ export async function getOlmToken(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, secret, token } = parsedBody.data;
|
||||
const { olmId, secret, token, orgId } = parsedBody.data;
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
@@ -61,11 +69,12 @@ export async function getOlmToken(
|
||||
}
|
||||
}
|
||||
|
||||
const existingOlmRes = await db
|
||||
const [existingOlm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.olmId, olmId));
|
||||
if (!existingOlmRes || !existingOlmRes.length) {
|
||||
|
||||
if (!existingOlm) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
@@ -74,12 +83,11 @@ export async function getOlmToken(
|
||||
);
|
||||
}
|
||||
|
||||
const existingOlm = existingOlmRes[0];
|
||||
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingOlm.secretHash
|
||||
);
|
||||
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
@@ -96,11 +104,115 @@ export async function getOlmToken(
|
||||
const resToken = generateSessionToken();
|
||||
await createOlmSession(resToken, existingOlm.olmId);
|
||||
|
||||
let orgIdToUse = orgId;
|
||||
let clientIdToUse;
|
||||
if (!orgIdToUse) {
|
||||
if (!existingOlm.clientId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Olm is not associated with a client, orgId is required"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, existingOlm.clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Olm's associated client not found, orgId is required"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
orgIdToUse = client.orgId;
|
||||
clientIdToUse = client.clientId;
|
||||
} else {
|
||||
// we did provide the org
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(eq(clients.orgId, orgIdToUse), eq(clients.olmId, olmId))
|
||||
) // we want to lock on to the client with this olmId otherwise it can get assigned to a random one
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No client found for provided orgId"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existingOlm.clientId !== client.clientId) {
|
||||
// we only need to do this if the client is changing
|
||||
|
||||
logger.debug(
|
||||
`Switching olm client ${existingOlm.olmId} to org ${orgId} for user ${existingOlm.userId}`
|
||||
);
|
||||
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
clientId: client.clientId
|
||||
})
|
||||
.where(eq(olms.olmId, existingOlm.olmId));
|
||||
}
|
||||
|
||||
clientIdToUse = client.clientId;
|
||||
}
|
||||
|
||||
// Get all exit nodes from sites where the client has peers
|
||||
const clientSites = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.innerJoin(
|
||||
sites,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!));
|
||||
|
||||
// Extract unique exit node IDs
|
||||
const exitNodeIds = Array.from(
|
||||
new Set(
|
||||
clientSites
|
||||
.map(({ sites: site }) => site.exitNodeId)
|
||||
.filter((id): id is number => id !== null)
|
||||
)
|
||||
);
|
||||
|
||||
let allExitNodes: ExitNode[] = [];
|
||||
if (exitNodeIds.length > 0) {
|
||||
allExitNodes = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
|
||||
}
|
||||
|
||||
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
|
||||
return {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
};
|
||||
});
|
||||
|
||||
logger.debug("Token created successfully");
|
||||
|
||||
return response<{ token: string }>(res, {
|
||||
return response<{
|
||||
token: string;
|
||||
exitNodes: { publicKey: string; endpoint: string }[];
|
||||
}>(res, {
|
||||
data: {
|
||||
token: resToken
|
||||
token: resToken,
|
||||
exitNodes: exitNodesHpData
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
|
||||
70
server/routers/olm/getUserOlm.ts
Normal file
70
server/routers/olm/getUserOlm.ts
Normal file
@@ -0,0 +1,70 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
userId: z.string(),
|
||||
olmId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/user/{userId}/olm/{olmId}",
|
||||
// description: "Get an olm for a user.",
|
||||
// tags: [OpenAPITags.User, OpenAPITags.Client],
|
||||
// request: {
|
||||
// params: paramsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function getUserOlm(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, userId } = parsedParams.data;
|
||||
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(and(eq(olms.userId, userId), eq(olms.olmId, olmId)));
|
||||
|
||||
return response(res, {
|
||||
data: olm,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Successfully retrieved olm",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to retrieve olm"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
import { db } from "@server/db";
|
||||
import { disconnectClient } from "#dynamic/routers/ws";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, Olm } from "@server/db";
|
||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
|
||||
// Track if the offline checker interval is running
|
||||
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||
@@ -20,10 +24,14 @@ export const startOlmOfflineChecker = (): void => {
|
||||
|
||||
offlineCheckerInterval = setInterval(async () => {
|
||||
try {
|
||||
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
|
||||
const twoMinutesAgo = Math.floor(
|
||||
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
||||
);
|
||||
|
||||
// TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
|
||||
|
||||
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
||||
await db
|
||||
const offlineClients = await db
|
||||
.update(clients)
|
||||
.set({ online: false })
|
||||
.where(
|
||||
@@ -34,8 +42,34 @@ export const startOlmOfflineChecker = (): void => {
|
||||
isNull(clients.lastPing)
|
||||
)
|
||||
)
|
||||
)
|
||||
.returning();
|
||||
|
||||
for (const offlineClient of offlineClients) {
|
||||
logger.info(
|
||||
`Kicking offline olm client ${offlineClient.clientId} due to inactivity`
|
||||
);
|
||||
|
||||
if (!offlineClient.olmId) {
|
||||
logger.warn(
|
||||
`Offline client ${offlineClient.clientId} has no olmId, cannot disconnect`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Send a disconnect message to the client if connected
|
||||
try {
|
||||
await sendTerminateClient(offlineClient.clientId, offlineClient.olmId); // terminate first
|
||||
// wait a moment to ensure the message is sent
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
await disconnectClient(offlineClient.olmId);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error sending disconnect to offline olm ${offlineClient.clientId}`,
|
||||
{ error }
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error in offline checker interval", { error });
|
||||
}
|
||||
@@ -62,11 +96,57 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
const { userToken } = message.data;
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (olm.userId) {
|
||||
// we need to check a user token to make sure its still valid
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
logger.warn("Invalid user session for olm ping");
|
||||
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
|
||||
}
|
||||
if (user.userId !== olm.userId) {
|
||||
logger.warn("User ID mismatch for olm ping");
|
||||
return;
|
||||
}
|
||||
|
||||
// get the client
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
eq(clients.olmId, olm.olmId),
|
||||
eq(clients.userId, olm.userId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client not found for olm ping");
|
||||
return;
|
||||
}
|
||||
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: client.orgId,
|
||||
userId: olm.userId,
|
||||
session: userToken // this is the user token passed in the message
|
||||
});
|
||||
|
||||
if (!policyCheck.allowed) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no client ID!");
|
||||
return;
|
||||
@@ -78,7 +158,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
||||
.update(clients)
|
||||
.set({
|
||||
lastPing: Math.floor(Date.now() / 1000),
|
||||
online: true,
|
||||
online: true
|
||||
})
|
||||
.where(eq(clients.clientId, olm.clientId));
|
||||
} catch (error) {
|
||||
@@ -89,7 +169,7 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
||||
message: {
|
||||
type: "pong",
|
||||
data: {
|
||||
timestamp: new Date().toISOString(),
|
||||
timestamp: new Date().toISOString()
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
|
||||
@@ -1,31 +1,115 @@
|
||||
import { db, ExitNode } from "@server/db";
|
||||
import {
|
||||
Client,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
db,
|
||||
ExitNode,
|
||||
Org,
|
||||
orgs,
|
||||
roleClients,
|
||||
roles,
|
||||
siteResources,
|
||||
Transaction,
|
||||
userClients,
|
||||
userOrgs,
|
||||
users
|
||||
} from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db";
|
||||
import { and, eq, inArray } from "drizzle-orm";
|
||||
import {
|
||||
clients,
|
||||
clientSitesAssociationsCache,
|
||||
exitNodes,
|
||||
Olm,
|
||||
olms,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import { and, eq, inArray, isNull } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../newt/peers";
|
||||
import logger from "@server/logger";
|
||||
import { listExitNodes } from "#dynamic/lib/exitNodes";
|
||||
import {
|
||||
generateAliasConfig,
|
||||
getNextAvailableClientSubnet
|
||||
} from "@server/lib/ip";
|
||||
import { generateRemoteSubnets } from "@server/lib/ip";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
logger.info("Handling register olm message!");
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
const now = new Date().getTime() / 1000;
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
const { publicKey, relay, olmVersion, olmAgent, orgId, userToken } = message.data;
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no client ID!");
|
||||
logger.warn("Olm client ID not found");
|
||||
return;
|
||||
}
|
||||
const clientId = olm.clientId;
|
||||
const { publicKey, relay, olmVersion } = message.data;
|
||||
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, olm.clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client ID not found");
|
||||
return;
|
||||
}
|
||||
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, client.orgId))
|
||||
.limit(1);
|
||||
|
||||
if (!org) {
|
||||
logger.warn("Org not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (orgId) {
|
||||
if (!olm.userId) {
|
||||
logger.warn("Olm has no user ID");
|
||||
return;
|
||||
}
|
||||
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
logger.warn("Invalid user session for olm register");
|
||||
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
|
||||
}
|
||||
if (user.userId !== olm.userId) {
|
||||
logger.warn("User ID mismatch for olm register");
|
||||
return;
|
||||
}
|
||||
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: orgId,
|
||||
userId: olm.userId,
|
||||
session: userToken // this is the user token passed in the message
|
||||
});
|
||||
|
||||
if (!policyCheck.allowed) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Olm client ID: ${clientId}, Public Key: ${publicKey}, Relay: ${relay}`
|
||||
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
|
||||
);
|
||||
|
||||
if (!publicKey) {
|
||||
@@ -33,66 +117,16 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the client
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (client.exitNodeId) {
|
||||
// TODO: FOR NOW WE ARE JUST HOLEPUNCHING ALL EXIT NODES BUT IN THE FUTURE WE SHOULD HANDLE THIS BETTER
|
||||
|
||||
// Get the exit node
|
||||
const allExitNodes = await listExitNodes(client.orgId, true); // FILTER THE ONLINE ONES
|
||||
|
||||
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
|
||||
return {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
};
|
||||
});
|
||||
|
||||
// Send holepunch message
|
||||
await sendToClient(olm.olmId, {
|
||||
type: "olm/wg/holepunch/all",
|
||||
data: {
|
||||
exitNodes: exitNodesHpData
|
||||
}
|
||||
});
|
||||
|
||||
if (!olmVersion) {
|
||||
// THIS IS FOR BACKWARDS COMPATIBILITY
|
||||
// THE OLDER CLIENTS DID NOT SEND THE VERSION
|
||||
await sendToClient(olm.olmId, {
|
||||
type: "olm/wg/holepunch",
|
||||
data: {
|
||||
serverPubKey: allExitNodes[0].publicKey,
|
||||
endpoint: allExitNodes[0].endpoint
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (olmVersion) {
|
||||
if ((olmVersion && olm.version !== olmVersion) || (olmAgent && olm.agent !== olmAgent)) {
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
version: olmVersion
|
||||
version: olmVersion,
|
||||
agent: olmAgent
|
||||
})
|
||||
.where(eq(olms.olmId, olm.olmId));
|
||||
}
|
||||
|
||||
// if (now - (client.lastHolePunch || 0) > 6) {
|
||||
// logger.warn("Client last hole punch is too old, skipping all sites");
|
||||
// return;
|
||||
// }
|
||||
|
||||
if (client.pubKey !== publicKey) {
|
||||
logger.info(
|
||||
"Public key mismatch. Updating public key and clearing session info..."
|
||||
@@ -103,23 +137,26 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
.set({
|
||||
pubKey: publicKey
|
||||
})
|
||||
.where(eq(clients.clientId, olm.clientId));
|
||||
.where(eq(clients.clientId, client.clientId));
|
||||
|
||||
// set isRelay to false for all of the client's sites to reset the connection metadata
|
||||
await db
|
||||
.update(clientSites)
|
||||
.update(clientSitesAssociationsCache)
|
||||
.set({
|
||||
isRelayed: relay == true
|
||||
})
|
||||
.where(eq(clientSites.clientId, olm.clientId));
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
}
|
||||
|
||||
// Get all sites data
|
||||
const sitesData = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
|
||||
.where(eq(clientSites.clientId, client.clientId));
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Prepare an array to store site configurations
|
||||
const siteConfigurations = [];
|
||||
@@ -127,15 +164,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
`Found ${sitesData.length} sites for client ${client.clientId}`
|
||||
);
|
||||
|
||||
if (sitesData.length === 0) {
|
||||
sendToClient(olm.olmId, {
|
||||
type: "olm/register/no-sites",
|
||||
data: {}
|
||||
});
|
||||
// this prevents us from accepting a register from an olm that has not hole punched yet.
|
||||
// the olm will pump the register so we can keep checking
|
||||
// TODO: I still think there is a better way to do this rather than locking it out here but ???
|
||||
if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) {
|
||||
logger.warn(
|
||||
"Client last hole punch is too old and we have sites to send; skipping this register"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Process each site
|
||||
for (const { sites: site } of sitesData) {
|
||||
for (const { sites: site, clientSitesAssociationsCache: association } of sitesData) {
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} does not have exit node, skipping`
|
||||
@@ -145,7 +185,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
|
||||
// Validate endpoint and hole punch status
|
||||
if (!site.endpoint) {
|
||||
logger.warn(`In olm register: site ${site.siteId} has no endpoint, skipping`);
|
||||
logger.warn(
|
||||
`In olm register: site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -171,11 +213,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
|
||||
const [clientSite] = await db
|
||||
.select()
|
||||
.from(clientSites)
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSites.clientId, client.clientId),
|
||||
eq(clientSites.siteId, site.siteId)
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
@@ -196,7 +238,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
);
|
||||
}
|
||||
|
||||
let endpoint = site.endpoint;
|
||||
let relayEndpoint: string | undefined = undefined;
|
||||
if (relay) {
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
@@ -207,17 +249,43 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
continue;
|
||||
}
|
||||
endpoint = `${exitNode.endpoint}:21820`;
|
||||
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
||||
}
|
||||
|
||||
const allSiteResources = await db // only get the site resources that this client has access to
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
siteResources.siteResourceId,
|
||||
clientSiteResourcesAssociationsCache.siteResourceId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(siteResources.siteId, site.siteId),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Add site configuration to the array
|
||||
siteConfigurations.push({
|
||||
siteId: site.siteId,
|
||||
endpoint: endpoint,
|
||||
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
|
||||
endpoint: site.endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: site.remoteSubnets
|
||||
remoteSubnets: generateRemoteSubnets(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
@@ -233,7 +301,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
type: "olm/wg/connect",
|
||||
data: {
|
||||
sites: siteConfigurations,
|
||||
tunnelIP: client.subnet
|
||||
tunnelIP: client.subnet,
|
||||
utilitySubnet: org.utilitySubnet
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { db, exitNodes, sites } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, clientSites, Olm } from "@server/db";
|
||||
import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { updatePeer } from "../newt/peers";
|
||||
import { updatePeer as newtUpdatePeer } from "../newt/peers";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
||||
@@ -67,30 +67,31 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
|
||||
await db
|
||||
.update(clientSites)
|
||||
.update(clientSitesAssociationsCache)
|
||||
.set({
|
||||
isRelayed: true
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(clientSites.clientId, olm.clientId),
|
||||
eq(clientSites.siteId, siteId)
|
||||
eq(clientSitesAssociationsCache.clientId, olm.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, siteId)
|
||||
)
|
||||
);
|
||||
|
||||
// update the peer on the exit node
|
||||
await updatePeer(siteId, client.pubKey, {
|
||||
endpoint: "" // this removes the endpoint
|
||||
await newtUpdatePeer(siteId, client.pubKey, {
|
||||
endpoint: "" // this removes the endpoint so the exit node knows to relay
|
||||
});
|
||||
|
||||
sendToClient(olm.olmId, {
|
||||
type: "olm/wg/peer/relay",
|
||||
data: {
|
||||
siteId: siteId,
|
||||
endpoint: exitNode.endpoint,
|
||||
publicKey: exitNode.publicKey
|
||||
}
|
||||
});
|
||||
|
||||
return;
|
||||
return {
|
||||
message: {
|
||||
type: "olm/wg/peer/relay",
|
||||
data: {
|
||||
siteId: siteId,
|
||||
relayEndpoint: exitNode.endpoint
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
|
||||
187
server/routers/olm/handleOlmServerPeerAddMessage.ts
Normal file
187
server/routers/olm/handleOlmServerPeerAddMessage.ts
Normal file
@@ -0,0 +1,187 @@
|
||||
import {
|
||||
Client,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
db,
|
||||
ExitNode,
|
||||
Org,
|
||||
orgs,
|
||||
roleClients,
|
||||
roles,
|
||||
siteResources,
|
||||
Transaction,
|
||||
userClients,
|
||||
userOrgs,
|
||||
users
|
||||
} from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import {
|
||||
clients,
|
||||
clientSitesAssociationsCache,
|
||||
exitNodes,
|
||||
Olm,
|
||||
olms,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../newt/peers";
|
||||
import logger from "@server/logger";
|
||||
import { listExitNodes } from "#dynamic/lib/exitNodes";
|
||||
import {
|
||||
generateAliasConfig,
|
||||
getNextAvailableClientSubnet
|
||||
} from "@server/lib/ip";
|
||||
import { generateRemoteSubnets } from "@server/lib/ip";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import config from "@server/lib/config";
|
||||
import {
|
||||
addPeer as newtAddPeer,
|
||||
deletePeer as newtDeletePeer
|
||||
} from "@server/routers/newt/peers";
|
||||
|
||||
export const handleOlmServerPeerAddMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
logger.info("Handling register olm message!");
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
const { siteId } = message.data;
|
||||
|
||||
// get the site
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Site with ID ${siteId} not found`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!site.endpoint) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Site with ID ${siteId} has no endpoint`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// get the client
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Olm with ID ${olm.olmId} has no clientId`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(and(eq(clients.clientId, olm.clientId)))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Client with ID ${olm.clientId} not found`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!client.pubKey) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Client with ID ${client.clientId} has no public key`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let endpoint: string | null = null;
|
||||
|
||||
// TODO: should we pick only the one from the site its talking to instead of any good current session?
|
||||
const currentSessionSiteAssociationCaches = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
isNotNull(clientSitesAssociationsCache.endpoint),
|
||||
eq(clientSitesAssociationsCache.publicKey, client.pubKey) // limit it to the current session its connected with otherwise the endpoint could be stale
|
||||
)
|
||||
);
|
||||
|
||||
// pick an endpoint
|
||||
for (const assoc of currentSessionSiteAssociationCaches) {
|
||||
if (assoc.endpoint) {
|
||||
endpoint = assoc.endpoint;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!endpoint) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: No endpoint found for client ${client.clientId}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// NOTE: here we are always starting direct to the peer and will relay later
|
||||
|
||||
await newtAddPeer(siteId, {
|
||||
publicKey: client.pubKey,
|
||||
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
endpoint: endpoint // this is the client's endpoint with reference to the site's exit node
|
||||
});
|
||||
|
||||
const allSiteResources = await db // only get the site resources that this client has access to
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
siteResources.siteResourceId,
|
||||
clientSiteResourcesAssociationsCache.siteResourceId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(siteResources.siteId, site.siteId),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Return connect message with all site configurations
|
||||
return {
|
||||
message: {
|
||||
type: "olm/wg/peer/add",
|
||||
data: {
|
||||
siteId: site.siteId,
|
||||
endpoint: site.endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: generateRemoteSubnets(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
96
server/routers/olm/handleOlmUnRelayMessage.ts
Normal file
96
server/routers/olm/handleOlmUnRelayMessage.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { db, exitNodes, sites } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { updatePeer as newtUpdatePeer } from "../newt/peers";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
logger.info("Handling unrelay olm message!");
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
|
||||
return;
|
||||
}
|
||||
|
||||
const clientId = olm.clientId;
|
||||
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client not found");
|
||||
return;
|
||||
}
|
||||
|
||||
// make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old
|
||||
if (!client.pubKey) {
|
||||
logger.warn("Client has no endpoint or listen port");
|
||||
return;
|
||||
}
|
||||
|
||||
const { siteId } = message.data;
|
||||
|
||||
// Get the site
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site) {
|
||||
logger.warn("Site not found or has no exit node");
|
||||
return;
|
||||
}
|
||||
|
||||
const [clientSiteAssociation] = await db
|
||||
.update(clientSitesAssociationsCache)
|
||||
.set({
|
||||
isRelayed: false
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, olm.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, siteId)
|
||||
)
|
||||
)
|
||||
.returning();
|
||||
|
||||
if (!clientSiteAssociation) {
|
||||
logger.warn("Client-Site association not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!clientSiteAssociation.endpoint) {
|
||||
logger.warn("Client-Site association has no endpoint, cannot unrelay");
|
||||
return;
|
||||
}
|
||||
|
||||
// update the peer on the exit node
|
||||
await newtUpdatePeer(siteId, client.pubKey, {
|
||||
endpoint: clientSiteAssociation.endpoint // this is the endpoint of the client to connect directly to the exit node
|
||||
});
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "olm/wg/peer/unrelay",
|
||||
data: {
|
||||
siteId: siteId,
|
||||
endpoint: site.endpoint
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
@@ -1,5 +1,11 @@
|
||||
export * from "./handleOlmRegisterMessage";
|
||||
export * from "./getOlmToken";
|
||||
export * from "./createOlm";
|
||||
export * from "./createUserOlm";
|
||||
export * from "./handleOlmRelayMessage";
|
||||
export * from "./handleOlmPingMessage";
|
||||
export * from "./handleOlmPingMessage";
|
||||
export * from "./deleteUserOlm";
|
||||
export * from "./listUserOlms";
|
||||
export * from "./deleteUserOlm";
|
||||
export * from "./getUserOlm";
|
||||
export * from "./handleOlmServerPeerAddMessage";
|
||||
export * from "./handleOlmUnRelayMessage";
|
||||
139
server/routers/olm/listUserOlms.ts
Normal file
139
server/routers/olm/listUserOlms.ts
Normal file
@@ -0,0 +1,139 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import { eq, count, desc } from "drizzle-orm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const querySchema = z.object({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().positive()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().nonnegative())
|
||||
});
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
userId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "delete",
|
||||
// path: "/user/{userId}/olms",
|
||||
// description: "List all olms for a user.",
|
||||
// tags: [OpenAPITags.User, OpenAPITags.Client],
|
||||
// request: {
|
||||
// query: querySchema,
|
||||
// params: paramsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export type ListUserOlmsResponse = {
|
||||
olms: Array<{
|
||||
olmId: string;
|
||||
dateCreated: string;
|
||||
version: string | null;
|
||||
name: string | null;
|
||||
clientId: number | null;
|
||||
userId: string | null;
|
||||
}>;
|
||||
pagination: {
|
||||
total: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
};
|
||||
};
|
||||
|
||||
export async function listUserOlms(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedQuery = querySchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { limit, offset } = parsedQuery.data;
|
||||
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { userId } = parsedParams.data;
|
||||
|
||||
// Get total count
|
||||
const [totalCountResult] = await db
|
||||
.select({ count: count() })
|
||||
.from(olms)
|
||||
.where(eq(olms.userId, userId));
|
||||
|
||||
const total = totalCountResult?.count || 0;
|
||||
|
||||
// Get OLMs for the current user
|
||||
const userOlms = await db
|
||||
.select({
|
||||
olmId: olms.olmId,
|
||||
dateCreated: olms.dateCreated,
|
||||
version: olms.version,
|
||||
name: olms.name,
|
||||
clientId: olms.clientId,
|
||||
userId: olms.userId
|
||||
})
|
||||
.from(olms)
|
||||
.where(eq(olms.userId, userId))
|
||||
.orderBy(desc(olms.dateCreated))
|
||||
.limit(limit)
|
||||
.offset(offset);
|
||||
|
||||
return response<ListUserOlmsResponse>(res, {
|
||||
data: {
|
||||
olms: userOlms,
|
||||
pagination: {
|
||||
total,
|
||||
limit,
|
||||
offset
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Olms retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to list OLMs"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import { clients, olms, newts, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { Alias } from "yaml";
|
||||
|
||||
export async function addPeer(
|
||||
clientId: number,
|
||||
@@ -10,54 +11,74 @@ export async function addPeer(
|
||||
siteId: number;
|
||||
publicKey: string;
|
||||
endpoint: string;
|
||||
relayEndpoint: string;
|
||||
serverIP: string | null;
|
||||
serverPort: number | null;
|
||||
remoteSubnets: string | null; // optional, comma-separated list of subnets that this site can access
|
||||
}
|
||||
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
|
||||
aliases: Alias[];
|
||||
},
|
||||
olmId?: string
|
||||
) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
throw new Error(`Olm with ID ${clientId} not found`);
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return; // ignore this because an olm might not be associated with the client anymore
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olm.olmId, {
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets // optional, comma-separated list of subnets that this site can access
|
||||
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
|
||||
aliases: peer.aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
|
||||
logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
|
||||
}
|
||||
|
||||
export async function deletePeer(clientId: number, siteId: number, publicKey: string) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
throw new Error(`Olm with ID ${clientId} not found`);
|
||||
export async function deletePeer(
|
||||
clientId: number,
|
||||
siteId: number,
|
||||
publicKey: string,
|
||||
olmId?: string
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olm.olmId, {
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/remove",
|
||||
data: {
|
||||
publicKey,
|
||||
siteId: siteId
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`);
|
||||
logger.info(`Deleted peer ${publicKey} from olm ${olmId}`);
|
||||
}
|
||||
|
||||
export async function updatePeer(
|
||||
@@ -66,31 +87,80 @@ export async function updatePeer(
|
||||
siteId: number;
|
||||
publicKey: string;
|
||||
endpoint: string;
|
||||
serverIP: string | null;
|
||||
serverPort: number | null;
|
||||
remoteSubnets?: string | null; // optional, comma-separated list of subnets that
|
||||
}
|
||||
relayEndpoint?: string;
|
||||
serverIP?: string | null;
|
||||
serverPort?: number | null;
|
||||
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
|
||||
aliases?: Alias[] | null;
|
||||
},
|
||||
olmId?: string
|
||||
) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
throw new Error(`Olm with ID ${clientId} not found`);
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olm.olmId, {
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/update",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets
|
||||
remoteSubnets: peer.remoteSubnets,
|
||||
aliases: peer.aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
|
||||
logger.info(`Updated peer ${peer.publicKey} on olm ${olmId}`);
|
||||
}
|
||||
|
||||
export async function initPeerAddHandshake(
|
||||
clientId: number,
|
||||
peer: {
|
||||
siteId: number;
|
||||
exitNode: {
|
||||
publicKey: string;
|
||||
endpoint: string;
|
||||
};
|
||||
},
|
||||
olmId?: string
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
if (!olm) {
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/holepunch/site/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
exitNode: {
|
||||
publicKey: peer.exitNode.publicKey,
|
||||
endpoint: peer.exitNode.endpoint
|
||||
}
|
||||
}
|
||||
}).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
logger.info(`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`);
|
||||
}
|
||||
|
||||
@@ -26,12 +26,13 @@ import { createCustomer } from "#dynamic/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { build } from "@server/build";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
|
||||
const createOrgSchema = z.strictObject({
|
||||
orgId: z.string(),
|
||||
name: z.string().min(1).max(255),
|
||||
subnet: z.string()
|
||||
});
|
||||
orgId: z.string(),
|
||||
name: z.string().min(1).max(255),
|
||||
subnet: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "put",
|
||||
@@ -131,12 +132,16 @@ export async function createOrg(
|
||||
.from(domains)
|
||||
.where(eq(domains.configManaged, true));
|
||||
|
||||
const utilitySubnet =
|
||||
config.getRawConfig().orgs.utility_subnet_group;
|
||||
|
||||
const newOrg = await trx
|
||||
.insert(orgs)
|
||||
.values({
|
||||
orgId,
|
||||
name,
|
||||
subnet,
|
||||
utilitySubnet,
|
||||
createdAt: new Date().toISOString()
|
||||
})
|
||||
.returning();
|
||||
@@ -190,6 +195,7 @@ export async function createOrg(
|
||||
);
|
||||
}
|
||||
|
||||
let ownerUserId: string | null = null;
|
||||
if (req.user) {
|
||||
await trx.insert(userOrgs).values({
|
||||
userId: req.user!.userId,
|
||||
@@ -197,6 +203,7 @@ export async function createOrg(
|
||||
roleId: roleId,
|
||||
isOwner: true
|
||||
});
|
||||
ownerUserId = req.user!.userId;
|
||||
} else {
|
||||
// if org created by root api key, set the server admin as the owner
|
||||
const [serverAdmin] = await trx
|
||||
@@ -216,6 +223,7 @@ export async function createOrg(
|
||||
roleId: roleId,
|
||||
isOwner: true
|
||||
});
|
||||
ownerUserId = serverAdmin.userId;
|
||||
}
|
||||
|
||||
const memberRole = await trx
|
||||
@@ -234,6 +242,8 @@ export async function createOrg(
|
||||
orgId
|
||||
}))
|
||||
);
|
||||
|
||||
await calculateUserClientsForOrgs(ownerUserId, trx);
|
||||
});
|
||||
|
||||
if (!org) {
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, domains, orgDomains, resources } from "@server/db";
|
||||
import {
|
||||
clients,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
clientSitesAssociationsCache,
|
||||
db,
|
||||
domains,
|
||||
olms,
|
||||
orgDomains,
|
||||
resources
|
||||
} from "@server/db";
|
||||
import { newts, newtSessions, orgs, sites, userActions } from "@server/db";
|
||||
import { eq, and, inArray, sql } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
@@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const deleteOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export type DeleteOrgResponse = {};
|
||||
|
||||
@@ -69,41 +78,75 @@ export async function deleteOrg(
|
||||
.where(eq(sites.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
const orgClients = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.orgId, orgId));
|
||||
|
||||
const deletedNewtIds: string[] = [];
|
||||
const olmsToTerminate: string[] = [];
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
if (sites) {
|
||||
for (const site of orgSites) {
|
||||
if (site.pubKey) {
|
||||
if (site.type == "wireguard") {
|
||||
await deletePeer(site.exitNodeId!, site.pubKey);
|
||||
} else if (site.type == "newt") {
|
||||
// get the newt on the site by querying the newt table for siteId
|
||||
const [deletedNewt] = await trx
|
||||
.delete(newts)
|
||||
.where(eq(newts.siteId, site.siteId))
|
||||
.returning();
|
||||
if (deletedNewt) {
|
||||
deletedNewtIds.push(deletedNewt.newtId);
|
||||
for (const site of orgSites) {
|
||||
if (site.pubKey) {
|
||||
if (site.type == "wireguard") {
|
||||
await deletePeer(site.exitNodeId!, site.pubKey);
|
||||
} else if (site.type == "newt") {
|
||||
// get the newt on the site by querying the newt table for siteId
|
||||
const [deletedNewt] = await trx
|
||||
.delete(newts)
|
||||
.where(eq(newts.siteId, site.siteId))
|
||||
.returning();
|
||||
if (deletedNewt) {
|
||||
deletedNewtIds.push(deletedNewt.newtId);
|
||||
|
||||
// delete all of the sessions for the newt
|
||||
await trx
|
||||
.delete(newtSessions)
|
||||
.where(
|
||||
eq(
|
||||
newtSessions.newtId,
|
||||
deletedNewt.newtId
|
||||
)
|
||||
);
|
||||
}
|
||||
// delete all of the sessions for the newt
|
||||
await trx
|
||||
.delete(newtSessions)
|
||||
.where(
|
||||
eq(newtSessions.newtId, deletedNewt.newtId)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Deleting site ${site.siteId}`);
|
||||
await trx
|
||||
.delete(sites)
|
||||
.where(eq(sites.siteId, site.siteId));
|
||||
}
|
||||
|
||||
logger.info(`Deleting site ${site.siteId}`);
|
||||
await trx.delete(sites).where(eq(sites.siteId, site.siteId));
|
||||
}
|
||||
for (const client of orgClients) {
|
||||
const [olm] = await trx
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, client.clientId))
|
||||
.limit(1);
|
||||
|
||||
if (olm) {
|
||||
olmsToTerminate.push(olm.olmId);
|
||||
}
|
||||
|
||||
logger.info(`Deleting client ${client.clientId}`);
|
||||
await trx
|
||||
.delete(clients)
|
||||
.where(eq(clients.clientId, client.clientId));
|
||||
|
||||
// also delete the associations
|
||||
await trx
|
||||
.delete(clientSiteResourcesAssociationsCache)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
);
|
||||
|
||||
await trx
|
||||
.delete(clientSitesAssociationsCache)
|
||||
.where(
|
||||
eq(
|
||||
clientSitesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const allOrgDomains = await trx
|
||||
@@ -150,7 +193,7 @@ export async function deleteOrg(
|
||||
// Send termination messages outside of transaction to prevent blocking
|
||||
for (const newtId of deletedNewtIds) {
|
||||
const payload = {
|
||||
type: `newt/terminate`,
|
||||
type: `newt/wg/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
@@ -162,6 +205,18 @@ export async function deleteOrg(
|
||||
});
|
||||
}
|
||||
|
||||
for (const olmId of olmsToTerminate) {
|
||||
sendToClient(olmId, {
|
||||
type: "olm/terminate",
|
||||
data: {}
|
||||
}).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to olm:",
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
|
||||
@@ -130,7 +130,7 @@ export async function getOrgOverview(
|
||||
numSites,
|
||||
numUsers,
|
||||
numResources,
|
||||
isAdmin: role.name === "Admin",
|
||||
isAdmin: role.isAdmin || false,
|
||||
isOwner: req.userOrg?.isOwner || false
|
||||
},
|
||||
success: true,
|
||||
|
||||
161
server/routers/resource/addRoleToResource.ts
Normal file
161
server/routers/resource/addRoleToResource.ts
Normal file
@@ -0,0 +1,161 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, resources } from "@server/db";
|
||||
import { roleResources, roles } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const addRoleToResourceBodySchema = z
|
||||
.object({
|
||||
roleId: z.number().int().positive()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const addRoleToResourceParamsSchema = z
|
||||
.object({
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().positive())
|
||||
})
|
||||
.strict();
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/resource/{resourceId}/roles/add",
|
||||
description: "Add a single role to a resource.",
|
||||
tags: [OpenAPITags.Resource, OpenAPITags.Role],
|
||||
request: {
|
||||
params: addRoleToResourceParamsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: addRoleToResourceBodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function addRoleToResource(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = addRoleToResourceBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { roleId } = parsedBody.data;
|
||||
|
||||
const parsedParams = addRoleToResourceParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { resourceId } = parsedParams.data;
|
||||
|
||||
// get the resource
|
||||
const [resource] = await db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(eq(resources.resourceId, resourceId))
|
||||
.limit(1);
|
||||
|
||||
if (!resource) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Resource not found")
|
||||
);
|
||||
}
|
||||
|
||||
// verify the role exists and belongs to the same org
|
||||
const [role] = await db
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(
|
||||
and(
|
||||
eq(roles.roleId, roleId),
|
||||
eq(roles.orgId, resource.orgId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!role) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"Role not found or does not belong to the same organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if the role is an admin role
|
||||
if (role.isAdmin) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Admin role cannot be assigned to resources"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if role already exists in resource
|
||||
const existingEntry = await db
|
||||
.select()
|
||||
.from(roleResources)
|
||||
.where(
|
||||
and(
|
||||
eq(roleResources.resourceId, resourceId),
|
||||
eq(roleResources.roleId, roleId)
|
||||
)
|
||||
);
|
||||
|
||||
if (existingEntry.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"Role already assigned to resource"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db.insert(roleResources).values({
|
||||
roleId,
|
||||
resourceId
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: {},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Role added to resource successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
130
server/routers/resource/addUserToResource.ts
Normal file
130
server/routers/resource/addUserToResource.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, resources } from "@server/db";
|
||||
import { userResources } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const addUserToResourceBodySchema = z
|
||||
.object({
|
||||
userId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const addUserToResourceParamsSchema = z
|
||||
.object({
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().positive())
|
||||
})
|
||||
.strict();
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/resource/{resourceId}/users/add",
|
||||
description: "Add a single user to a resource.",
|
||||
tags: [OpenAPITags.Resource, OpenAPITags.User],
|
||||
request: {
|
||||
params: addUserToResourceParamsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: addUserToResourceBodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function addUserToResource(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = addUserToResourceBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { userId } = parsedBody.data;
|
||||
|
||||
const parsedParams = addUserToResourceParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { resourceId } = parsedParams.data;
|
||||
|
||||
// get the resource
|
||||
const [resource] = await db
|
||||
.select()
|
||||
.from(resources)
|
||||
.where(eq(resources.resourceId, resourceId))
|
||||
.limit(1);
|
||||
|
||||
if (!resource) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Resource not found")
|
||||
);
|
||||
}
|
||||
|
||||
// Check if user already exists in resource
|
||||
const existingEntry = await db
|
||||
.select()
|
||||
.from(userResources)
|
||||
.where(
|
||||
and(
|
||||
eq(userResources.resourceId, resourceId),
|
||||
eq(userResources.userId, userId)
|
||||
)
|
||||
);
|
||||
|
||||
if (existingEntry.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"User already assigned to resource"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db.insert(userResources).values({
|
||||
userId,
|
||||
resourceId
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: {},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "User added to resource successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,4 +25,8 @@ export * from "./getUserResources";
|
||||
export * from "./setResourceHeaderAuth";
|
||||
export * from "./addEmailToResourceWhitelist";
|
||||
export * from "./removeEmailFromResourceWhitelist";
|
||||
export * from "./addRoleToResource";
|
||||
export * from "./removeRoleFromResource";
|
||||
export * from "./addUserToResource";
|
||||
export * from "./removeUserFromResource";
|
||||
export * from "./listAllResourceNames";
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user