Merge branch 'hp-multi-client' into auth-providers-clients

This commit is contained in:
miloschwartz
2025-04-20 16:15:40 -04:00
70 changed files with 27368 additions and 159 deletions

View File

@@ -64,6 +64,11 @@ export enum ActionsEnum {
deleteResourceRule = "deleteResourceRule",
listResourceRules = "listResourceRules",
updateResourceRule = "updateResourceRule",
createClient = "createClient",
deleteClient = "deleteClient",
updateClient = "updateClient",
listClients = "listClients",
getClient = "getClient",
listOrgDomains = "listOrgDomains",
createNewt = "createNewt"
}

View File

@@ -0,0 +1,72 @@
import {
encodeHexLowerCase,
} from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { Olm, olms, olmSessions, OlmSession } from "@server/db/schema";
import db from "@server/db";
import { eq } from "drizzle-orm";
export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
export async function createOlmSession(
token: string,
olmId: string,
): Promise<OlmSession> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const session: OlmSession = {
sessionId: sessionId,
olmId,
expiresAt: new Date(Date.now() + EXPIRES).getTime(),
};
await db.insert(olmSessions).values(session);
return session;
}
export async function validateOlmSessionToken(
token: string,
): Promise<SessionValidationResult> {
const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(token)),
);
const result = await db
.select({ olm: olms, session: olmSessions })
.from(olmSessions)
.innerJoin(olms, eq(olmSessions.olmId, olms.olmId))
.where(eq(olmSessions.sessionId, sessionId));
if (result.length < 1) {
return { session: null, olm: null };
}
const { olm, session } = result[0];
if (Date.now() >= session.expiresAt) {
await db
.delete(olmSessions)
.where(eq(olmSessions.sessionId, session.sessionId));
return { session: null, olm: null };
}
if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
session.expiresAt = new Date(
Date.now() + EXPIRES,
).getTime();
await db
.update(olmSessions)
.set({
expiresAt: session.expiresAt,
})
.where(eq(olmSessions.sessionId, session.sessionId));
}
return { session, olm };
}
export async function invalidateOlmSession(sessionId: string): Promise<void> {
await db.delete(olmSessions).where(eq(olmSessions.sessionId, sessionId));
}
export async function invalidateAllOlmSessions(olmId: string): Promise<void> {
await db.delete(olmSessions).where(eq(olmSessions.olmId, olmId));
}
export type SessionValidationResult =
| { session: OlmSession; olm: Olm }
| { session: null; olm: null };

View File

@@ -11,7 +11,8 @@ export const domains = sqliteTable("domains", {
export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(),
name: text("name").notNull()
name: text("name").notNull(),
subnet: text("subnet").notNull(),
});
export const orgDomains = sqliteTable("orgDomains", {
@@ -41,7 +42,14 @@ export const sites = sqliteTable("sites", {
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false)
online: integer("online", { mode: "boolean" }).notNull().default(false),
// exit node stuff that is how to connect to the site when it has a wg server
address: text("address"), // this is the address of the wireguard interface in newt
endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config
publicKey: text("pubicKey"),
lastHolePunch: integer("lastHolePunch"),
listenPort: integer("listenPort")
});
export const resources = sqliteTable("resources", {
@@ -136,6 +144,48 @@ export const newts = sqliteTable("newt", {
})
});
export const clients = sqliteTable("clients", {
clientId: integer("id").primaryKey({ autoIncrement: true }),
orgId: text("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull(),
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null"
}),
name: text("name").notNull(),
pubKey: text("pubKey"),
subnet: text("subnet").notNull(),
megabytesIn: integer("bytesIn"),
megabytesOut: integer("bytesOut"),
lastBandwidthUpdate: text("lastBandwidthUpdate"),
lastPing: text("lastPing"),
type: text("type").notNull(), // "olm"
online: integer("online", { mode: "boolean" }).notNull().default(false),
endpoint: text("endpoint"),
lastHolePunch: integer("lastHolePunch")
});
export const clientSites = sqliteTable("clientSites", {
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
isRelayed: integer("isRelayed", { mode: "boolean" }).notNull().default(false)
});
export const olms = sqliteTable("olms", {
olmId: text("id").primaryKey(),
secretHash: text("secretHash").notNull(),
dateCreated: text("dateCreated").notNull(),
clientId: integer("clientId").references(() => clients.clientId, {
onDelete: "cascade"
})
});
export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
codeId: integer("id").primaryKey({ autoIncrement: true }),
userId: text("userId")
@@ -160,6 +210,14 @@ export const newtSessions = sqliteTable("newtSession", {
expiresAt: integer("expiresAt").notNull()
});
export const olmSessions = sqliteTable("clientSession", {
sessionId: text("id").primaryKey(),
olmId: text("olmId")
.notNull()
.references(() => olms.olmId, { onDelete: "cascade" }),
expiresAt: integer("expiresAt").notNull()
});
export const userOrgs = sqliteTable("userOrgs", {
userId: text("userId")
.notNull()
@@ -255,6 +313,24 @@ export const userSites = sqliteTable("userSites", {
.references(() => sites.siteId, { onDelete: "cascade" })
});
export const userClients = sqliteTable("userClients", {
userId: text("userId")
.notNull()
.references(() => users.userId, { onDelete: "cascade" }),
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" })
});
export const roleClients = sqliteTable("roleClients", {
roleId: integer("roleId")
.notNull()
.references(() => roles.roleId, { onDelete: "cascade" }),
clientId: integer("clientId")
.notNull()
.references(() => clients.clientId, { onDelete: "cascade" })
});
export const roleResources = sqliteTable("roleResources", {
roleId: integer("roleId")
.notNull()
@@ -473,6 +549,8 @@ export type Target = InferSelectModel<typeof targets>;
export type Session = InferSelectModel<typeof sessions>;
export type Newt = InferSelectModel<typeof newts>;
export type NewtSession = InferSelectModel<typeof newtSessions>;
export type Olm = InferSelectModel<typeof olms>;
export type OlmSession = InferSelectModel<typeof olmSessions>;
export type EmailVerificationCode = InferSelectModel<
typeof emailVerificationCodes
>;
@@ -497,6 +575,10 @@ export type ResourceAccessToken = InferSelectModel<typeof resourceAccessToken>;
export type ResourceWhitelist = InferSelectModel<typeof resourceWhitelist>;
export type VersionMigration = InferSelectModel<typeof versionMigrations>;
export type ResourceRule = InferSelectModel<typeof resourceRules>;
export type Client = InferSelectModel<typeof clients>;
export type ClientSite = InferSelectModel<typeof clientSites>;
export type RoleClient = InferSelectModel<typeof roleClients>;
export type UserClient = InferSelectModel<typeof userClients>;
export type Domain = InferSelectModel<typeof domains>;
export type SupporterKey = InferSelectModel<typeof supporterKey>;
export type Idp = InferSelectModel<typeof idp>;

View File

@@ -12,7 +12,6 @@ import { passwordSchema } from "@server/auth/passwordSchema";
import stoi from "./stoi";
import db from "@server/db";
import { SupporterKey, supporterKey } from "@server/db/schemas";
import { suppressDeprecationWarnings } from "moment";
import { eq } from "drizzle-orm";
const portSchema = z.number().positive().gt(0).lte(65535);
@@ -122,6 +121,10 @@ const configSchema = z.object({
block_size: z.number().positive().gt(0),
site_block_size: z.number().positive().gt(0)
}),
orgs: z.object({
block_size: z.number().positive().gt(0),
subnet_group: z.string(),
}),
rate_limits: z.object({
global: z.object({
window_minutes: z.number().positive().gt(0),

View File

@@ -4,7 +4,14 @@ import { assertEquals } from "@test/assert";
// Test cases
function testFindNextAvailableCidr() {
console.log("Running findNextAvailableCidr tests...");
// Test 0: Basic IPv4 allocation with a subnet in the wrong range
{
const existing = ["100.90.130.1/30", "100.90.128.4/30"];
const result = findNextAvailableCidr(existing, 30, "100.90.130.1/24");
assertEquals(result, "100.90.130.4/30", "Basic IPv4 allocation failed");
}
// Test 1: Basic IPv4 allocation
{
const existing = ["10.0.0.0/16", "10.1.0.0/16"];
@@ -26,6 +33,12 @@ function testFindNextAvailableCidr() {
assertEquals(result, null, "No available space test failed");
}
// Test 4: Empty existing
{
const existing: string[] = [];
const result = findNextAvailableCidr(existing, 30, "10.0.0.0/8");
assertEquals(result, "10.0.0.0/30", "Empty existing test failed");
}
// // Test 4: IPv6 allocation
// {
// const existing = ["2001:db8::/32", "2001:db8:1::/32"];

View File

@@ -1,3 +1,8 @@
import db from "@server/db";
import { clients, orgs, sites } from "@server/db/schema";
import { and, eq, isNotNull } from "drizzle-orm";
import config from "@server/lib/config";
interface IPRange {
start: bigint;
end: bigint;
@@ -132,7 +137,6 @@ export function findNextAvailableCidr(
blockSize: number,
startCidr?: string
): string | null {
if (!startCidr && existingCidrs.length === 0) {
return null;
}
@@ -150,40 +154,47 @@ export function findNextAvailableCidr(
existingCidrs.some(cidr => detectIpVersion(cidr.split('/')[0]) !== version)) {
throw new Error('All CIDRs must be of the same IP version');
}
// Extract the network part from startCidr to ensure we stay in the right subnet
const startCidrRange = cidrToRange(startCidr);
// Convert existing CIDRs to ranges and sort them
const existingRanges = existingCidrs
.map(cidr => cidrToRange(cidr))
.sort((a, b) => (a.start < b.start ? -1 : 1));
// Calculate block size
const maxPrefix = version === 4 ? 32 : 128;
const blockSizeBigInt = BigInt(1) << BigInt(maxPrefix - blockSize);
// Start from the beginning of the given CIDR
let current = cidrToRange(startCidr).start;
const maxIp = cidrToRange(startCidr).end;
let current = startCidrRange.start;
const maxIp = startCidrRange.end;
// Iterate through existing ranges
for (let i = 0; i <= existingRanges.length; i++) {
const nextRange = existingRanges[i];
// Align current to block size
const alignedCurrent = current + ((blockSizeBigInt - (current % blockSizeBigInt)) % blockSizeBigInt);
// Check if we've gone beyond the maximum allowed IP
if (alignedCurrent + blockSizeBigInt - BigInt(1) > maxIp) {
return null;
}
// If we're at the end of existing ranges or found a gap
if (!nextRange || alignedCurrent + blockSizeBigInt - BigInt(1) < nextRange.start) {
return `${bigIntToIp(alignedCurrent, version)}/${blockSize}`;
}
// Move current pointer to after the current range
current = nextRange.end + BigInt(1);
// If next range overlaps with our search space, move past it
if (nextRange.end >= startCidrRange.start && nextRange.start <= maxIp) {
// Move current pointer to after the current range
current = nextRange.end + BigInt(1);
}
}
return null;
}
@@ -204,4 +215,63 @@ export function isIpInCidr(ip: string, cidr: string): boolean {
const ipBigInt = ipToBigInt(ip);
const range = cidrToRange(cidr);
return ipBigInt >= range.start && ipBigInt <= range.end;
}
export async function getNextAvailableClientSubnet(orgId: string): Promise<string> {
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
const existingAddressesSites = await db
.select({
address: sites.address
})
.from(sites)
.where(and(isNotNull(sites.address), eq(sites.orgId, orgId)));
const existingAddressesClients = await db
.select({
address: clients.subnet
})
.from(clients)
.where(and(isNotNull(clients.subnet), eq(clients.orgId, orgId)));
const addresses = [
...existingAddressesSites.map((site) => `${site.address?.split("/")[0]}/32`), // we are overriding the 32 so that we pick individual addresses in the subnet of the org for the site and the client even though they are stored with the /block_size of the org
...existingAddressesClients.map((client) => `${client.address.split("/")}/32`)
].filter((address) => address !== null) as string[];
let subnet = findNextAvailableCidr(
addresses,
32,
org.subnet
); // pick the sites address in the org
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
return subnet;
}
export async function getNextAvailableOrgSubnet(): Promise<string> {
const existingAddresses = await db
.select({
subnet: orgs.subnet
})
.from(orgs)
.where(isNotNull(orgs.subnet));
const addresses = existingAddresses.map((org) => org.subnet);
let subnet = findNextAvailableCidr(
addresses,
config.getRawConfig().orgs.block_size,
config.getRawConfig().orgs.subnet_group
);
if (!subnet) {
throw new Error("No available subnets remaining in space");
}
return subnet;
}

View File

@@ -16,3 +16,4 @@ export * from "./verifyUserInRole";
export * from "./verifyAccessTokenAccess";
export * from "./verifyUserIsServerAdmin";
export * from "./verifyIsLoggedInUser";
export * from "./verifyClientAccess";

View File

@@ -0,0 +1,131 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { userOrgs, clients, roleClients, userClients } from "@server/db/schema";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
export async function verifyClientAccess(
req: Request,
res: Response,
next: NextFunction
) {
const userId = req.user!.userId; // Assuming you have user information in the request
const clientId = parseInt(
req.params.clientId || req.body.clientId || req.query.clientId
);
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
if (isNaN(clientId)) {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid client ID"));
}
try {
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (!client.orgId) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Client with ID ${clientId} does not have an organization ID`
)
);
}
if (!req.userOrg) {
// Get user's role ID in the organization
const userOrgRole = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, client.orgId)
)
)
.limit(1);
req.userOrg = userOrgRole[0];
}
if (!req.userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
const userOrgRoleId = req.userOrg.roleId;
req.userOrgRoleId = userOrgRoleId;
req.userOrgId = client.orgId;
// Check role-based site access first
const [roleClientAccess] = await db
.select()
.from(roleClients)
.where(
and(
eq(roleClients.clientId, clientId),
eq(roleClients.roleId, userOrgRoleId)
)
)
.limit(1);
if (roleClientAccess) {
// User has access to the site through their role
return next();
}
// If role doesn't have access, check user-specific site access
const [userClientAccess] = await db
.select()
.from(userClients)
.where(
and(
eq(userClients.userId, userId),
eq(userClients.clientId, clientId)
)
)
.limit(1);
if (userClientAccess) {
// User has direct access to the site
return next();
}
// If we reach here, the user doesn't have access to the site
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this client"
)
);
} catch (error) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying site access"
)
);
}
}

View File

@@ -0,0 +1,236 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import {
roles,
Client,
clients,
roleClients,
userClients,
olms,
clientSites,
exitNodes,
orgs,
sites
} from "@server/db/schema";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import moment from "moment";
import { hashPassword } from "@server/auth/password";
import { isValidCIDR, isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
const createClientParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const createClientSchema = z
.object({
name: z.string().min(1).max(255),
siteIds: z.array(z.number().int().positive()),
olmId: z.string(),
secret: z.string(),
subnet: z.string(),
type: z.enum(["olm"])
})
.strict();
export type CreateClientBody = z.infer<typeof createClientSchema>;
export type CreateClientResponse = Client;
export async function createClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createClientSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, type, siteIds, olmId, secret, subnet } = parsedBody.data;
const parsedParams = createClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
if (!req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
if (!isValidIP(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
if (!isIpInCidr(subnet, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const subnetExistsClients = await db
.select()
.from(clients)
.where(eq(clients.subnet, updatedSubnet))
.limit(1);
if (subnetExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
const subnetExistsSites = await db
.select()
.from(sites)
.where(eq(sites.address, updatedSubnet))
.limit(1);
if (subnetExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
await db.transaction(async (trx) => {
// TODO: more intelligent way to pick the exit node
// make sure there is an exit node by counting the exit nodes table
const nodes = await db.select().from(exitNodes);
if (nodes.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No exit nodes available"
)
);
}
// get the first exit node
const exitNode = nodes[0];
const adminRole = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
trx.rollback();
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: exitNode.exitNodeId,
orgId,
name,
subnet: updatedSubnet,
type
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole[0].roleId,
clientId: newClient.clientId
});
if (req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
trx.insert(userClients).values({
userId: req.user?.userId!,
clientId: newClient.clientId
});
}
// Create site to client associations
if (siteIds && siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map((siteId) => ({
clientId: newClient.clientId,
siteId
}))
);
}
const secretHash = await hashPassword(secret);
await trx.insert(olms).values({
olmId,
secretHash,
clientId: newClient.clientId,
dateCreated: moment().toISOString()
});
return response<CreateClientResponse>(res, {
data: newClient,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,76 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db/schema";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
const deleteClientSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
export async function deleteClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = deleteClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
await db.transaction(async (trx) => {
// Delete the client-site associations first
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
// Then delete the client itself
await trx
.delete(clients)
.where(eq(clients.clientId, clientId));
});
return response(res, {
data: null,
success: true,
error: false,
message: "Client deleted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,74 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients } from "@server/db/schema";
import { eq, and } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import stoi from "@server/lib/stoi";
import { fromError } from "zod-validation-error";
const getClientSchema = z
.object({
clientId: z
.string()
.transform(stoi)
.pipe(z.number().int().positive()),
orgId: z.string().optional()
})
.strict();
async function query(clientId: number) {
const [res] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
return res;
}
export type GetClientResponse = NonNullable<Awaited<ReturnType<typeof query>>>;
export async function getClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = getClientSchema.safeParse(req.params);
if (!parsedParams.success) {
logger.error(
`Error parsing params: ${fromError(parsedParams.error).toString()}`
);
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const client = await query(clientId);
if (!client) {
return next(createHttpError(HttpCode.NOT_FOUND, "Client not found"));
}
return response<GetClientResponse>(res, {
data: client,
success: true,
error: false,
message: "Client retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,6 @@
export * from "./pickClientDefaults";
export * from "./createClient";
export * from "./deleteClient";
export * from "./listClients";
export * from "./updateClient";
export * from "./getClient";

View File

@@ -0,0 +1,208 @@
import { db } from "@server/db";
import {
clients,
orgs,
roleClients,
sites,
userClients,
clientSites
} from "@server/db/schema";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { and, count, eq, inArray, or, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const listClientsParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const listClientsSchema = z.object({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.number().int().positive()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.number().int().nonnegative())
});
function queryClients(orgId: string, accessibleClientIds: number[]) {
return db
.select({
clientId: clients.clientId,
orgId: clients.orgId,
name: clients.name,
pubKey: clients.pubKey,
subnet: clients.subnet,
megabytesIn: clients.megabytesIn,
megabytesOut: clients.megabytesOut,
orgName: orgs.name,
type: clients.type,
online: clients.online
})
.from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
}
async function getSiteAssociations(clientIds: number[]) {
if (clientIds.length === 0) return [];
return db
.select({
clientId: clientSites.clientId,
siteId: clientSites.siteId,
siteName: sites.name,
siteNiceId: sites.niceId
})
.from(clientSites)
.leftJoin(sites, eq(clientSites.siteId, sites.siteId))
.where(inArray(clientSites.clientId, clientIds));
}
export type ListClientsResponse = {
clients: Array<Awaited<ReturnType<typeof queryClients>>[0] & { sites: Array<{
siteId: number;
siteName: string | null;
siteNiceId: string | null;
}> }>;
pagination: { total: number; limit: number; offset: number };
};
export async function listClients(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedQuery = listClientsSchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error)
)
);
}
const { limit, offset } = parsedQuery.data;
const parsedParams = listClientsParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error)
)
);
}
const { orgId } = parsedParams.data;
if (orgId && orgId !== req.userOrgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
const accessibleClients = await db
.select({
clientId: sql<number>`COALESCE(${userClients.clientId}, ${roleClients.clientId})`
})
.from(userClients)
.fullJoin(
roleClients,
eq(userClients.clientId, roleClients.clientId)
)
.where(
or(
eq(userClients.userId, req.user!.userId),
eq(roleClients.roleId, req.userOrgRoleId!)
)
);
const accessibleClientIds = accessibleClients.map(
(client) => client.clientId
);
const baseQuery = queryClients(orgId, accessibleClientIds);
// Get client count
const countQuery = db
.select({ count: count() })
.from(clients)
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
const clientsList = await baseQuery.limit(limit).offset(offset);
const totalCountResult = await countQuery;
const totalCount = totalCountResult[0].count;
// Get associated sites for all clients
const clientIds = clientsList.map(client => client.clientId);
const siteAssociations = await getSiteAssociations(clientIds);
// Group site associations by client ID
const sitesByClient = siteAssociations.reduce((acc, association) => {
if (!acc[association.clientId]) {
acc[association.clientId] = [];
}
acc[association.clientId].push({
siteId: association.siteId,
siteName: association.siteName,
siteNiceId: association.siteNiceId
});
return acc;
}, {} as Record<number, Array<{
siteId: number;
siteName: string | null;
siteNiceId: string | null;
}>>);
// Merge clients with their site associations
const clientsWithSites = clientsList.map(client => ({
...client,
sites: sitesByClient[client.clientId] || []
}));
return response<ListClientsResponse>(res, {
data: {
clients: clientsWithSites,
pagination: {
total: totalCount,
limit,
offset
}
},
success: true,
error: false,
message: "Clients retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,73 @@
import { Request, Response, NextFunction } from "express";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { generateId } from "@server/auth/sessions/app";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import { z } from "zod";
import { fromError } from "zod-validation-error";
export type PickClientDefaultsResponse = {
olmId: string;
olmSecret: string;
subnet: string;
};
const pickClientDefaultsSchema = z
.object({
orgId: z.string()
})
.strict();
export async function pickClientDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = pickClientDefaultsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const olmId = generateId(15);
const secret = generateId(48);
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnet found"
)
);
}
const subnet = newSubnet.split("/")[0];
return response<PickClientDefaultsResponse>(res, {
data: {
olmId: olmId,
olmSecret: secret,
subnet: subnet
},
success: true,
error: false,
message: "Organization retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,124 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import {
clients,
clientSites
} from "@server/db/schema";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
const updateClientParamsSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
const updateClientSchema = z
.object({
name: z.string().min(1).max(255).optional(),
siteIds: z.array(z.string().transform(Number).pipe(z.number())).optional()
})
.strict();
export type UpdateClientBody = z.infer<typeof updateClientSchema>;
export async function updateClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = updateClientSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, siteIds } = parsedBody.data;
const parsedParams = updateClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Fetch the client to make sure it exists and the user has access to it
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
await db.transaction(async (trx) => {
// Update client name if provided
if (name) {
await trx
.update(clients)
.set({ name })
.where(eq(clients.clientId, clientId));
}
// Update site associations if provided
if (siteIds) {
// Delete existing site associations
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
// Create new site associations
if (siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map(siteId => ({
clientId,
siteId
}))
);
}
}
// Fetch the updated client
const [updatedClient] = await trx
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
return response(res, {
data: updatedClient,
success: true,
error: false,
message: "Client updated successfully",
status: HttpCode.OK
});
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -8,6 +8,7 @@ import * as target from "./target";
import * as user from "./user";
import * as auth from "./auth";
import * as role from "./role";
import * as client from "./client";
import * as supporterKey from "./supporterKey";
import * as accessToken from "./accessToken";
import * as idp from "./idp";
@@ -26,12 +27,14 @@ import {
verifyUserAccess,
getUserOrgs,
verifyUserIsServerAdmin,
verifyIsLoggedInUser
verifyIsLoggedInUser,
verifyClientAccess,
} from "@server/middlewares";
import { verifyUserHasAction } from "../middlewares/verifyUserHasAction";
import { ActionsEnum } from "@server/auth/actions";
import { verifyUserIsOrgOwner } from "../middlewares/verifyUserIsOrgOwner";
import { createNewt, getToken } from "./newt";
import { createNewt, getNewtToken } from "./newt";
import { getOlmToken } from "./olm";
import rateLimit from "express-rate-limit";
import createHttpError from "http-errors";
@@ -46,6 +49,10 @@ unauthenticated.get("/", (_, res) => {
export const authenticated = Router();
authenticated.use(verifySessionUserMiddleware);
authenticated.get(
"/pick-org-defaults",
org.pickOrgDefaults
);
authenticated.get("/org/checkId", org.checkId);
authenticated.put("/org", getUserOrgs, org.createOrg);
@@ -102,6 +109,49 @@ authenticated.get(
verifyUserHasAction(ActionsEnum.getSite),
site.getSite
);
authenticated.get(
"/org/:orgId/pick-client-defaults",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.pickClientDefaults
);
authenticated.get(
"/org/:orgId/clients",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listClients),
client.listClients
);
authenticated.get(
"/org/:orgId/client/:clientId",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.getClient),
client.getClient
);
authenticated.put(
"/org/:orgId/client",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.createClient
);
authenticated.delete(
"/client/:clientId",
verifyClientAccess,
verifyUserHasAction(ActionsEnum.deleteClient),
client.deleteClient
);
authenticated.post(
"/client/:clientId",
verifyClientAccess, // this will check if the user has access to the client
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
client.updateClient
);
// authenticated.get(
// "/site/:siteId/roles",
// verifySiteAccess,
@@ -559,7 +609,8 @@ authRouter.use(
authRouter.put("/signup", auth.signup);
authRouter.post("/login", auth.login);
authRouter.post("/logout", auth.logout);
authRouter.post("/newt/get-token", getToken);
authRouter.post("/newt/get-token", getNewtToken);
authRouter.post("/olm/get-token", getOlmToken);
authRouter.post("/2fa/enable", verifySessionUserMiddleware, auth.verifyTotp);
authRouter.post(

View File

@@ -0,0 +1,160 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, exitNodes, newts, olms, Site, sites, clientSites } from "@server/db/schema";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
// Define Zod schema for request validation
const getAllRelaysSchema = z.object({
publicKey: z.string().optional(),
});
// Type for peer destination
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
// Updated mappings type to support multiple destinations per endpoint
interface ProxyMapping {
destinations: PeerDestination[];
}
export async function getAllRelays(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// Validate request parameters
const parsedParams = getAllRelaysSchema.safeParse(req.body);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { publicKey } = parsedParams.data;
if (!publicKey) {
return next(createHttpError(HttpCode.BAD_REQUEST, 'publicKey is required'));
}
// Fetch exit node
let [exitNode] = await db.select().from(exitNodes).where(eq(exitNodes.publicKey, publicKey));
if (!exitNode) {
return next(createHttpError(HttpCode.NOT_FOUND, "Exit node not found"));
}
// Fetch sites for this exit node
const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode.exitNodeId));
if (sitesRes.length === 0) {
return res.status(HttpCode.OK).send({
mappings: {}
});
}
// Initialize mappings object for multi-peer support
let mappings: { [key: string]: ProxyMapping } = {};
// Process each site
for (const site of sitesRes) {
if (!site.endpoint || !site.subnet || !site.listenPort) {
continue;
}
// Find all clients associated with this site through clientSites
const clientSitesRes = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, site.siteId));
for (const clientSite of clientSitesRes) {
// Get client information
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientSite.clientId));
if (!client || !client.endpoint) {
continue;
}
// Add this site as a destination for the client
if (!mappings[client.endpoint]) {
mappings[client.endpoint] = { destinations: [] };
}
// Add site as a destination for this client
const destination: PeerDestination = {
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
};
// Check if this destination is already in the array to avoid duplicates
const isDuplicate = mappings[client.endpoint].destinations.some(
dest => dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[client.endpoint].destinations.push(destination);
}
}
// Also handle site-to-site communication (all sites in the same org)
if (site.orgId) {
const orgSites = await db
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
for (const peer of orgSites) {
// Skip self
if (peer.siteId === site.siteId || !peer.endpoint || !peer.subnet || !peer.listenPort) {
continue;
}
// Add peer site as a destination for this site
if (!mappings[site.endpoint]) {
mappings[site.endpoint] = { destinations: [] };
}
const destination: PeerDestination = {
destinationIP: peer.subnet.split("/")[0],
destinationPort: peer.listenPort
};
// Check for duplicates
const isDuplicate = mappings[site.endpoint].destinations.some(
dest => dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[site.endpoint].destinations.push(destination);
}
}
}
}
logger.debug(`Returning mappings for ${Object.keys(mappings).length} endpoints`);
return res.status(HttpCode.OK).send({ mappings });
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
}
}

View File

@@ -79,14 +79,12 @@ export async function getConfig(req: Request, res: Response, next: NextFunction)
}
// Fetch sites for this exit node
const sitesRes = await db.query.sites.findMany({
where: eq(sites.exitNodeId, exitNode[0].exitNodeId),
});
const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode[0].exitNodeId));
const peers = await Promise.all(sitesRes.map(async (site) => {
return {
publicKey: site.pubKey,
allowedIps: await getAllowedIps(site.siteId)
allowedIps: await getAllowedIps(site.siteId) // put 0.0.0.0/0 for now
};
}));

View File

@@ -1,2 +1,4 @@
export * from "./getConfig";
export * from "./receiveBandwidth";
export * from "./updateHolePunch";
export * from "./getAllRelays";

View File

@@ -30,12 +30,13 @@ export const receiveBandwidth = async (
const { publicKey, bytesIn, bytesOut } = peer;
// Find the site by public key
const site = await trx.query.sites.findFirst({
where: eq(sites.pubKey, publicKey)
});
const [site] = await trx
.select()
.from(sites)
.where(eq(sites.pubKey, publicKey))
.limit(1);
if (!site) {
logger.warn(`Site not found for public key: ${publicKey}`);
continue;
}
let online = site.online;

View File

@@ -0,0 +1,242 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, newts, olms, Site, sites, clientSites } from "@server/db/schema";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
// Define Zod schema for request validation
const updateHolePunchSchema = z.object({
olmId: z.string().optional(),
newtId: z.string().optional(),
token: z.string(),
ip: z.string(),
port: z.number(),
timestamp: z.number()
});
// New response type with multi-peer destination support
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
export async function updateHolePunch(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// Validate request parameters
const parsedParams = updateHolePunchSchema.safeParse(req.body);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId, newtId, ip, port, timestamp, token } = parsedParams.data;
let currentSiteId: number | undefined;
let destinations: PeerDestination[] = [];
if (olmId) {
logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`);
const { session, olm: olmSession } =
await validateOlmSessionToken(token);
if (!session || !olmSession) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
if (olmId !== olmSession.olmId) {
logger.warn(`Olm ID mismatch: ${olmId} !== ${olmSession.olmId}`);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!olm || !olm.clientId) {
logger.warn(`Olm not found: ${olmId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Olm not found")
);
}
const [client] = await db
.update(clients)
.set({
endpoint: `${ip}:${port}`,
lastHolePunch: timestamp
})
.where(eq(clients.clientId, olm.clientId))
.returning();
if (!client) {
logger.warn(`Client not found for olm: ${olmId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Client not found")
);
}
// Get all sites that this client is connected to
const clientSitePairs = await db
.select()
.from(clientSites)
.where(eq(clientSites.clientId, client.clientId));
if (clientSitePairs.length === 0) {
logger.warn(`No sites found for client: ${client.clientId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "No sites found for client")
);
}
// Get all sites details
const siteIds = clientSitePairs.map(pair => pair.siteId);
for (const siteId of siteIds) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (site && site.subnet && site.listenPort) {
destinations.push({
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
});
}
}
} else if (newtId) {
const { session, newt: newtSession } =
await validateNewtSessionToken(token);
if (!session || !newtSession) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
if (newtId !== newtSession.newtId) {
logger.warn(`Newt ID mismatch: ${newtId} !== ${newtSession.newtId}`);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!newt || !newt.siteId) {
logger.warn(`Newt not found: ${newtId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "New not found")
);
}
currentSiteId = newt.siteId;
// Update the current site with the new endpoint
const [updatedSite] = await db
.update(sites)
.set({
endpoint: `${ip}:${port}`,
lastHolePunch: timestamp
})
.where(eq(sites.siteId, newt.siteId))
.returning();
if (!updatedSite || !updatedSite.subnet) {
logger.warn(`Site not found: ${newt.siteId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Site not found")
);
}
// Find all clients that connect to this site
const sitesClientPairs = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, newt.siteId));
// Get client details for each client
for (const pair of sitesClientPairs) {
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, pair.clientId));
if (client && client.endpoint) {
const [host, portStr] = client.endpoint.split(':');
if (host && portStr) {
destinations.push({
destinationIP: host,
destinationPort: parseInt(portStr, 10)
});
}
}
}
// If this is a newt/site, also add other sites in the same org
// if (updatedSite.orgId) {
// const orgSites = await db
// .select()
// .from(sites)
// .where(eq(sites.orgId, updatedSite.orgId));
// for (const site of orgSites) {
// // Don't add the current site to the destinations
// if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) {
// const [host, portStr] = site.endpoint.split(':');
// if (host && portStr) {
// destinations.push({
// destinationIP: host,
// destinationPort: site.listenPort
// });
// }
// }
// }
// }
}
// if (destinations.length === 0) {
// logger.warn(
// `No peer destinations found for olmId: ${olmId} or newtId: ${newtId}`
// );
// return next(createHttpError(HttpCode.NOT_FOUND, "No peer destinations found"));
// }
// Return the new multi-peer structure
return res.status(HttpCode.OK).send({
destinations: destinations
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
}
}

View File

@@ -43,6 +43,8 @@ internalRouter.use("/gerbil", gerbilRouter);
gerbilRouter.post("/get-config", gerbil.getConfig);
gerbilRouter.post("/receive-bandwidth", gerbil.receiveBandwidth);
gerbilRouter.post("/update-hole-punch", gerbil.updateHolePunch);
gerbilRouter.post("/get-all-relays", gerbil.getAllRelays);
// Badger routes
const badgerRouter = Router();

View File

@@ -1,6 +1,14 @@
import { handleRegisterMessage } from "./newt";
import { handleNewtRegisterMessage, handleReceiveBandwidthMessage, handleGetConfigMessage } from "./newt";
import { handleOlmRegisterMessage, handleOlmRelayMessage, handleOlmPingMessage, startOfflineChecker } from "./olm";
import { MessageHandler } from "./ws";
export const messageHandlers: Record<string, MessageHandler> = {
"newt/wg/register": handleRegisterMessage,
};
"newt/wg/register": handleNewtRegisterMessage,
"olm/wg/register": handleOlmRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"olm/wg/relay": handleOlmRelayMessage,
"olm/ping": handleOlmPingMessage
};
startOfflineChecker(); // this is to handle the offline check for olms

View File

@@ -24,7 +24,7 @@ export const newtGetTokenBodySchema = z.object({
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
export async function getToken(
export async function getNewtToken(
req: Request,
res: Response,
next: NextFunction

View File

@@ -0,0 +1,161 @@
import { z } from "zod";
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import db from "@server/db";
import { clients, clientSites, Newt, sites } from "@server/db/schema";
import { eq } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
const inputSchema = z.object({
publicKey: z.string(),
port: z.number().int().positive()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
const now = new Date().getTime() / 1000;
logger.debug("Handling Newt get config message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, port } = message.data as Input;
const siteId = newt.siteId;
// Get the current site data
const [existingSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!existingSite) {
logger.warn("handleGetConfigMessage: Site not found");
return;
}
// todo check if the public key has changed
// we need to wait for hole punch success
if (!existingSite.endpoint) {
logger.warn(`Site ${existingSite.siteId} has no endpoint, skipping`);
return;
}
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) {
logger.warn(
`Site ${existingSite.siteId} last hole punch is too old, skipping`
);
return;
}
// update the endpoint and the public key
const [site] = await db
.update(sites)
.set({
publicKey,
listenPort: port
})
.where(eq(sites.siteId, siteId))
.returning();
if (!site) {
logger.error("handleGetConfigMessage: Failed to update site");
return;
}
// Get all clients connected to this site
const clientsRes = await db
.select()
.from(clients)
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
// Prepare peers data for the response
const peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
return false;
}
if (!client.clients.subnet) {
return false;
}
if (!client.clients.endpoint) {
return false;
}
if (!client.clients.online) {
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
try {
if (site.endpoint && site.publicKey) {
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
} catch (error) {
logger.error(
`Failed to add/update peer ${client.clients.pubKey} to newt ${newt.newtId}: ${error}`
);
}
return {
publicKey: client.clients.pubKey!,
allowedIps: [client.clients.subnet!],
endpoint: client.clientSites.isRelayed
? ""
: client.clients.endpoint! // if its relayed it should be localhost
};
})
);
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Build the configuration response
const configResponse = {
ipAddress: site.address,
peers: validPeers
};
logger.debug("Sending config: ", configResponse);
return {
message: {
type: "newt/wg/receive-config",
data: {
...configResponse
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -2,6 +2,7 @@ import db from "@server/db";
import { MessageHandler } from "../ws";
import {
exitNodes,
Newt,
resources,
sites,
Target,
@@ -11,10 +12,11 @@ import { eq, and, sql, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
export const handleRegisterMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context;
export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling register message!");
logger.info("Handling register newt message!");
if (!newt) {
logger.warn("Newt not found");

View File

@@ -0,0 +1,52 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Newt } from "@server/db/schema";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
bytesOut: number;
}
export const handleReceiveBandwidthMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
if (!message.data.bandwidthData) {
logger.warn("No bandwidth data provided");
}
const bandwidthData: PeerBandwidth[] = message.data.bandwidthData;
if (!Array.isArray(bandwidthData)) {
throw new Error("Invalid bandwidth data");
}
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Find the client by public key
const [client] = await trx
.select()
.from(clients)
.where(eq(clients.pubKey, publicKey))
.limit(1);
if (!client) {
continue;
}
// Update the client's bandwidth usage
await trx
.update(clients)
.set({
megabytesOut: (client.megabytesIn || 0) + bytesIn,
megabytesIn: (client.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
})
.where(eq(clients.clientId, client.clientId));
}
});
};

View File

@@ -1,3 +1,5 @@
export * from "./createNewt";
export * from "./getToken";
export * from "./handleRegisterMessage";
export * from "./getNewtToken";
export * from "./handleNewtRegisterMessage";
export * from "./handleReceiveBandwidthMessage";
export * from "./handleGetConfigMessage";

View File

@@ -0,0 +1,78 @@
import db from '@server/db';
import { newts, sites } from '@server/db/schema';
import { eq } from 'drizzle-orm';
import { sendToClient } from '../ws';
import logger from '@server/logger';
export async function addPeer(siteId: number, peer: {
publicKey: string;
allowedIps: string[];
endpoint: string;
}) {
const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1);
if (!newt) {
throw new Error(`Site found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: 'newt/wg/peer/add',
data: peer
});
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`);
}
export async function deletePeer(siteId: number, publicKey: string) {
const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: 'newt/wg/peer/remove',
data: {
publicKey
}
});
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`);
}
export async function updatePeer(siteId: number, publicKey: string, peer: {
allowedIps?: string[];
endpoint?: string;
}) {
const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: 'newt/wg/peer/update',
data: {
publicKey,
...peer
}
});
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`);
}

View File

@@ -0,0 +1,106 @@
import { NextFunction, Request, Response } from "express";
import db from "@server/db";
import { hash } from "@node-rs/argon2";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { newts } from "@server/db/schema";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import moment from "moment";
import { generateSessionToken } from "@server/auth/sessions/app";
import { createNewtSession } from "@server/auth/sessions/newt";
import { fromError } from "zod-validation-error";
import { hashPassword } from "@server/auth/password";
export const createNewtBodySchema = z.object({});
export type CreateNewtBody = z.infer<typeof createNewtBodySchema>;
export type CreateNewtResponse = {
token: string;
newtId: string;
secret: string;
};
const createNewtSchema = z
.object({
newtId: z.string(),
secret: z.string()
})
.strict();
export async function createNewt(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createNewtSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { newtId, secret } = parsedBody.data;
if (!req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
const secretHash = await hashPassword(secret);
await db.insert(newts).values({
newtId: newtId,
secretHash,
dateCreated: moment().toISOString(),
});
// give the newt their default permissions:
// await db.insert(newtActions).values({
// newtId: newtId,
// actionId: ActionsEnum.createOrg,
// orgId: null,
// });
const token = generateSessionToken();
await createNewtSession(token, newtId);
return response<CreateNewtResponse>(res, {
data: {
newtId,
secret,
token,
},
success: true,
error: false,
message: "Newt created successfully",
status: HttpCode.OK,
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A newt with that email address already exists"
)
);
} else {
console.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create newt"
)
);
}
}
}

View File

@@ -0,0 +1,119 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import db from "@server/db";
import { olms } from "@server/db/schema";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import {
createOlmSession,
validateOlmSessionToken
} from "@server/auth/sessions/olm";
import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
export const olmGetTokenBodySchema = z.object({
olmId: z.string(),
secret: z.string(),
token: z.string().optional()
});
export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>;
export async function getOlmToken(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = olmGetTokenBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { olmId, secret, token } = parsedBody.data;
try {
if (token) {
const { session, olm } = await validateOlmSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Olm session already valid. Olm ID: ${olmId}. IP: ${req.ip}.`
);
}
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Token session already valid",
status: HttpCode.OK
});
}
}
const existingOlmRes = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!existingOlmRes || !existingOlmRes.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No olm found with that olmId"
)
);
}
const existingOlm = existingOlmRes[0];
const validSecret = await verifyPassword(
secret,
existingOlm.secretHash
);
if (!validSecret) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
);
}
return next(
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
);
}
logger.debug("Creating new olm session token");
const resToken = generateSessionToken();
await createOlmSession(resToken, existingOlm.olmId);
logger.debug("Token created successfully");
return response<{ token: string }>(res, {
data: {
token: resToken
},
success: true,
error: false,
message: "Token created successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to authenticate olm"
)
);
}
}

View File

@@ -0,0 +1,93 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Olm } from "@server/db/schema";
import { eq, lt, isNull } from "drizzle-orm";
import logger from "@server/logger";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
*/
export const startOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = new Date(Date.now() - OFFLINE_THRESHOLD_MS);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
await db
.update(clients)
.set({ online: false })
.where(
eq(clients.online, true) &&
(lt(clients.lastPing, twoMinutesAgo.toISOString()) || isNull(clients.lastPing))
);
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.info("Started offline checker interval");
}
/**
* Stops the background interval that checks for offline clients
*/
export const stopOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped offline checker interval");
}
}
/**
* Handles ping messages from clients and responds with pong
*/
export const handleOlmPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
try {
// Update the client's last ping timestamp
await db
.update(clients)
.set({
lastPing: new Date().toISOString(),
online: true,
})
.where(eq(clients.clientId, olm.clientId));
} catch (error) {
logger.error("Error handling ping message", { error });
}
return {
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -0,0 +1,181 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import {
clients,
clientSites,
exitNodes,
Olm,
olms,
sites
} from "@server/db/schema";
import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
const now = new Date().getTime() / 1000;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
const clientId = olm.clientId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) {
// Get the exit node for this site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, client.exitNodeId))
.limit(1);
// Send holepunch message for each site
sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey
}
});
}
if (now - (client.lastHolePunch || 0) > 6) {
logger.warn("Client last hole punch is too old, skipping all sites");
return;
}
if (client.pubKey !== publicKey) {
logger.info(
"Public key mismatch. Updating public key and clearing session info..."
);
// Update the client's public key
await db
.update(clients)
.set({
pubKey: publicKey
})
.where(eq(clients.clientId, olm.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
.update(clientSites)
.set({
isRelayed: false
})
.where(eq(clientSites.clientId, olm.clientId));
}
// Get all sites data
const sitesData = await db
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.where(eq(clientSites.clientId, client.clientId));
// Prepare an array to store site configurations
const siteConfigurations = [];
// Process each site
for (const { sites: site } of sitesData) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
);
continue;
}
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`Site ${site.siteId} has no endpoint, skipping`);
continue;
}
if (site.lastHolePunch && now - site.lastHolePunch > 6) {
logger.warn(
`Site ${site.siteId} last hole punch is too old, skipping`
);
continue;
}
// If public key changed, delete old peer from this site
if (client.pubKey && client.pubKey != publicKey) {
logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
}
if (!site.subnet) {
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
continue;
}
// Add the peer to the exit node for this site
if (client.endpoint) {
logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${client.endpoint}`
);
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [client.subnet],
endpoint: client.endpoint
});
} else {
logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition`
);
}
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
// If we have no valid site configurations, don't send a connect message
if (siteConfigurations.length === 0) {
logger.warn("No valid site configurations found");
return;
}
// Return connect message with all site configurations
return {
message: {
type: "olm/wg/connect",
data: {
sites: siteConfigurations,
tunnelIP: client.subnet
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -0,0 +1,58 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import { clients, clientSites, Olm } from "@server/db/schema";
import { eq } from "drizzle-orm";
import { updatePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRelayMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
logger.info("Handling relay olm message!");
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Site not found or does not have exit node");
return;
}
// make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old
if (!client.pubKey) {
logger.warn("Site or client has no endpoint or listen port");
return;
}
const { siteId } = message.data;
await db
.update(clientSites)
.set({
isRelayed: true
})
.where(eq(clientSites.clientId, olm.clientId));
// update the peer on the exit node
await updatePeer(siteId, client.pubKey, {
endpoint: "" // this removes the endpoint
});
return;
};

View File

@@ -0,0 +1,5 @@
export * from "./handleOlmRegisterMessage";
export * from "./getOlmToken";
export * from "./createOlm";
export * from "./handleOlmRelayMessage";
export * from "./handleOlmPingMessage";

View File

@@ -0,0 +1,73 @@
import db from '@server/db';
import { clients, olms, newts } from '@server/db/schema';
import { eq } from 'drizzle-orm';
import { sendToClient } from '../ws';
import logger from '@server/logger';
export async function addPeer(clientId: number, peer: {
siteId: number,
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
}) {
const [olm] = await db.select().from(olms).where(eq(olms.clientId, clientId)).limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: 'olm/wg/peer/add',
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort
}
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
}
export async function deletePeer(clientId: number, publicKey: string) {
const [olm] = await db.select().from(olms).where(eq(olms.clientId, clientId)).limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: 'olm/wg/peer/remove',
data: {
publicKey
}
});
logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`);
}
export async function updatePeer(clientId: number, peer: {
siteId: number,
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
}) {
const [olm] = await db.select().from(olms).where(eq(olms.clientId, clientId)).limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: 'olm/wg/peer/update',
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort
}
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
}

View File

@@ -20,11 +20,13 @@ import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import { defaultRoleAllowedActions } from "../role";
import { OpenAPITags, registry } from "@server/openApi";
import { isValidCIDR } from "@server/lib/validators";
const createOrgSchema = z
.object({
orgId: z.string(),
name: z.string().min(1).max(255)
name: z.string().min(1).max(255),
subnet: z.string()
})
.strict();
@@ -85,7 +87,32 @@ export async function createOrg(
// );
// }
const { orgId, name } = parsedBody.data;
const { orgId, name, subnet } = parsedBody.data;
if (!isValidCIDR(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
// make sure the subnet is unique
const subnetExists = await db
.select()
.from(orgs)
.where(eq(orgs.subnet, subnet))
.limit(1);
if (subnetExists.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
// make sure the orgId is unique
const orgExists = await db
@@ -116,7 +143,8 @@ export async function createOrg(
.insert(orgs)
.values({
orgId,
name
name,
subnet,
})
.returning();

View File

@@ -6,3 +6,4 @@ export * from "./listUserOrgs";
export * from "./checkId";
export * from "./getOrgOverview";
export * from "./listOrgs";
export * from "./pickOrgDefaults";

View File

@@ -0,0 +1,35 @@
import { Request, Response, NextFunction } from "express";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
export type PickOrgDefaultsResponse = {
subnet: string;
};
export async function pickOrgDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const subnet = await getNextAvailableOrgSubnet();
return response<PickOrgDefaultsResponse>(res, {
data: {
subnet: subnet
},
success: true,
error: false,
message: "Organization defaults created successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { roles, userSites, sites, roleSites, Site } from "@server/db/schemas";
import { roles, userSites, sites, roleSites, Site, orgs } from "@server/db/schemas";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -10,11 +10,12 @@ import { eq, and } from "drizzle-orm";
import { getUniqueSiteName } from "@server/db/names";
import { addPeer } from "../gerbil/peers";
import { fromError } from "zod-validation-error";
import { hash } from "@node-rs/argon2";
import { newts } from "@server/db/schemas";
import moment from "moment";
import { OpenAPITags, registry } from "@server/openApi";
import { hashPassword } from "@server/auth/password";
import { isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
const createSiteParamsSchema = z
.object({
@@ -36,6 +37,7 @@ const createSiteSchema = z
subnet: z.string().optional(),
newtId: z.string().optional(),
secret: z.string().optional(),
address: z.string().optional(),
type: z.enum(["newt", "wireguard", "local"])
})
.strict();
@@ -78,8 +80,16 @@ export async function createSite(
);
}
const { name, type, exitNodeId, pubKey, subnet, newtId, secret } =
parsedBody.data;
const {
name,
type,
exitNodeId,
pubKey,
subnet,
newtId,
secret,
address
} = parsedBody.data;
const parsedParams = createSiteParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
@@ -99,6 +109,70 @@ export async function createSite(
);
}
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
let updatedAddress = null;
if (address) {
if (!isValidIP(address)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
if (!isIpInCidr(address, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
updatedAddress = `${address}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const addressExistsSites = await db
.select()
.from(sites)
.where(eq(sites.address, updatedAddress))
.limit(1);
if (addressExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
const addressExistsClients = await db
.select()
.from(sites)
.where(eq(sites.subnet, updatedAddress))
.limit(1);
if (addressExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
}
const niceId = await getUniqueSiteName(orgId);
await db.transaction(async (trx) => {
@@ -122,6 +196,7 @@ export async function createSite(
exitNodeId,
name,
niceId,
address: updatedAddress || null,
subnet,
type,
...(pubKey && type == "wireguard" && { pubKey })
@@ -136,6 +211,7 @@ export async function createSite(
orgId,
name,
niceId,
address: updatedAddress || null,
type,
subnet: "0.0.0.0/0"
})

View File

@@ -43,7 +43,8 @@ function querySites(orgId: string, accessibleSiteIds: number[]) {
megabytesOut: sites.megabytesOut,
orgName: orgs.name,
type: sites.type,
online: sites.online
online: sites.online,
address: sites.address,
})
.from(sites)
.leftJoin(orgs, eq(sites.orgId, orgs.orgId))

View File

@@ -6,10 +6,11 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { findNextAvailableCidr } from "@server/lib/ip";
import { findNextAvailableCidr, getNextAvailableClientSubnet } from "@server/lib/ip";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import { OpenAPITags, registry } from "@server/openApi";
import { fromError } from "zod-validation-error";
import { z } from "zod";
export type PickSiteDefaultsResponse = {
@@ -22,6 +23,7 @@ export type PickSiteDefaultsResponse = {
subnet: string;
newtId: string;
newtSecret: string;
clientAddress: string;
};
registry.registerPath({
@@ -38,12 +40,29 @@ registry.registerPath({
responses: {}
});
const pickSiteDefaultsSchema = z
.object({
orgId: z.string()
})
.strict();
export async function pickSiteDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = pickSiteDefaultsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
// TODO: more intelligent way to pick the exit node
// make sure there is an exit node by counting the exit nodes table
@@ -89,6 +108,18 @@ export async function pickSiteDefaults(
);
}
const newClientAddress = await getNextAvailableClientSubnet(orgId);
if (!newClientAddress) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnet found"
)
);
}
const clientAddress = newClientAddress.split("/")[0];
const newtId = generateId(15);
const secret = generateId(48);
@@ -100,7 +131,9 @@ export async function pickSiteDefaults(
name: exitNode.name,
listenPort: exitNode.listenPort,
endpoint: exitNode.endpoint,
// subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().gerbil.block_size}`, // we want the block size of the whole subnet
subnet: newSubnet,
clientAddress: clientAddress,
newtId,
newtSecret: secret
},

View File

@@ -3,10 +3,11 @@ import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { IncomingMessage } from "http";
import { Socket } from "net";
import { Newt, newts, NewtSession } from "@server/db/schemas";
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db/schemas";
import { eq } from "drizzle-orm";
import db from "@server/db";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
import { messageHandlers } from "./messageHandlers";
import logger from "@server/logger";
@@ -15,13 +16,17 @@ interface WebSocketRequest extends IncomingMessage {
token?: string;
}
type ClientType = 'newt' | 'olm';
interface AuthenticatedWebSocket extends WebSocket {
newt?: Newt;
client?: Newt | Olm;
clientType?: ClientType;
}
interface TokenPayload {
newt: Newt;
session: NewtSession;
client: Newt | Olm;
session: NewtSession | OlmSession;
clientType: ClientType;
}
interface WSMessage {
@@ -33,15 +38,16 @@ interface HandlerResponse {
message: WSMessage;
broadcast?: boolean;
excludeSender?: boolean;
targetNewtId?: string;
targetClientId?: string;
}
interface HandlerContext {
message: WSMessage;
senderWs: WebSocket;
newt: Newt | undefined;
sendToClient: (newtId: string, message: WSMessage) => boolean;
broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void;
client: Newt | Olm | undefined;
clientType: ClientType;
sendToClient: (clientId: string, message: WSMessage) => boolean;
broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => void;
connectedClients: Map<string, WebSocket[]>;
}
@@ -54,34 +60,32 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Helper functions for client management
const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
const existingClients = connectedClients.get(newtId) || [];
const addClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
const existingClients = connectedClients.get(clientId) || [];
existingClients.push(ws);
connectedClients.set(newtId, existingClients);
logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`);
connectedClients.set(clientId, existingClients);
logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Total connections: ${existingClients.length}`);
};
const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
const existingClients = connectedClients.get(newtId) || [];
const removeClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
const existingClients = connectedClients.get(clientId) || [];
const updatedClients = existingClients.filter(client => client !== ws);
if (updatedClients.length === 0) {
connectedClients.delete(newtId);
logger.info(`All connections removed for Newt ID: ${newtId}`);
connectedClients.delete(clientId);
logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`);
} else {
connectedClients.set(newtId, updatedClients);
logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`);
connectedClients.set(clientId, updatedClients);
logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`);
}
};
// Helper functions for sending messages
const sendToClient = (newtId: string, message: WSMessage): boolean => {
const clients = connectedClients.get(newtId);
const sendToClient = (clientId: string, message: WSMessage): boolean => {
const clients = connectedClients.get(clientId);
if (!clients || clients.length === 0) {
logger.info(`No active connections found for Newt ID: ${newtId}`);
logger.info(`No active connections found for Client ID: ${clientId}`);
return false;
}
const messageString = JSON.stringify(message);
clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) {
@@ -91,9 +95,9 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => {
return true;
};
const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => {
connectedClients.forEach((clients, newtId) => {
if (newtId !== excludeNewtId) {
const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): void => {
connectedClients.forEach((clients, clientId) => {
if (clientId !== excludeClientId) {
clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message));
@@ -103,84 +107,88 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void
});
};
// Token verification middleware (unchanged)
const verifyToken = async (token: string): Promise<TokenPayload | null> => {
// Token verification middleware
const verifyToken = async (token: string, clientType: ClientType): Promise<TokenPayload | null> => {
try {
const { session, newt } = await validateNewtSessionToken(token);
if (!session || !newt) {
return null;
if (clientType === 'newt') {
const { session, newt } = await validateNewtSessionToken(token);
if (!session || !newt) {
return null;
}
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { client: existingNewt[0], session, clientType };
} else {
const { session, olm } = await validateOlmSessionToken(token);
if (!session || !olm) {
return null;
}
const existingOlm = await db
.select()
.from(olms)
.where(eq(olms.olmId, olm.olmId));
if (!existingOlm || !existingOlm[0]) {
return null;
}
return { client: existingOlm[0], session, clientType };
}
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { newt: existingNewt[0], session };
} catch (error) {
logger.error("Token verification failed:", error);
return null;
}
};
const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): void => {
logger.info("Establishing websocket connection");
if (!newt) {
logger.error("Connection attempt without newt");
if (!client) {
logger.error("Connection attempt without client");
return ws.terminate();
}
ws.newt = newt;
ws.client = client;
ws.clientType = clientType;
// Add client to tracking
addClient(newt.newtId, ws);
const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId;
addClient(clientId, ws, clientType);
ws.on("message", async (data) => {
try {
const message: WSMessage = JSON.parse(data.toString());
// logger.info(`Message received from Newt ID ${newtId}:`, message);
// Validate message format
if (!message.type || typeof message.type !== "string") {
throw new Error("Invalid message format: missing or invalid type");
}
// Get the appropriate handler for the message type
const handler = messageHandlers[message.type];
if (!handler) {
throw new Error(`Unsupported message type: ${message.type}`);
}
// Process the message and get response
const response = await handler({
message,
senderWs: ws,
newt: ws.newt,
client: ws.client,
clientType: ws.clientType!,
sendToClient,
broadcastToAllExcept,
connectedClients
});
// Send response if one was returned
if (response) {
if (response.broadcast) {
// Broadcast to all clients except sender if specified
broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined);
} else if (response.targetNewtId) {
// Send to specific client if targetNewtId is provided
sendToClient(response.targetNewtId, response.message);
broadcastToAllExcept(response.message, response.excludeSender ? clientId : undefined);
} else if (response.targetClientId) {
sendToClient(response.targetClientId, response.message);
} else {
// Send back to sender
ws.send(JSON.stringify(response.message));
}
}
} catch (error) {
logger.error("Message handling error:", error);
ws.send(JSON.stringify({
@@ -194,18 +202,18 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
});
ws.on("close", () => {
removeClient(newt.newtId, ws);
logger.info(`Client disconnected - Newt ID: ${newt.newtId}`);
removeClient(clientId, ws, clientType);
logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`);
});
ws.on("error", (error: Error) => {
logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error);
logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error);
});
logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`);
logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`);
};
// Router endpoint (unchanged)
// Router endpoint
router.get("/ws", (req: Request, res: Response) => {
res.status(200).send("WebSocket endpoint");
});
@@ -214,18 +222,22 @@ router.get("/ws", (req: Request, res: Response) => {
const handleWSUpgrade = (server: HttpServer): void => {
server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
try {
const token = request.url?.includes("?")
? new URLSearchParams(request.url.split("?")[1]).get("token") || ""
: request.headers["sec-websocket-protocol"];
const url = new URL(request.url || '', `http://${request.headers.host}`);
const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || '';
let clientType = url.searchParams.get('clientType') as ClientType;
if (!token) {
logger.warn("Unauthorized connection attempt: no token...");
if (!clientType) {
clientType = "newt";
}
if (!token || !clientType || !['newt', 'olm'].includes(clientType)) {
logger.warn("Unauthorized connection attempt: invalid token or client type...");
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
socket.destroy();
return;
}
const tokenPayload = await verifyToken(token);
const tokenPayload = await verifyToken(token, clientType);
if (!tokenPayload) {
logger.warn("Unauthorized connection attempt: invalid token...");
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
@@ -234,7 +246,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
}
wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
setupConnection(ws, tokenPayload.newt);
setupConnection(ws, tokenPayload.client, tokenPayload.clientType);
});
} catch (error) {
logger.error("WebSocket upgrade error:", error);