Merge branch 'dev' of https://github.com/fosrl/pangolin into dev

This commit is contained in:
miloschwartz
2025-07-14 22:21:04 -07:00
234 changed files with 16088 additions and 7588 deletions

View File

@@ -112,7 +112,11 @@ export async function requestTotpSecret(
const hex = crypto.getRandomValues(new Uint8Array(20));
const secret = encodeHex(hex);
const uri = createTOTPKeyURI("Pangolin", user.email!, hex);
const uri = createTOTPKeyURI(
"Pangolin",
user.email!,
hex
);
await db
.update(users)

View File

@@ -1,8 +1,7 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { db, users } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { users } from "@server/db";
import { fromError } from "zod-validation-error";
import createHttpError from "http-errors";
import response from "@server/lib/response";
@@ -57,8 +56,6 @@ export async function signup(
const { email, password, inviteToken, inviteId } = parsedBody.data;
logger.debug("signup", { email, password, inviteToken, inviteId });
const passwordHash = await hashPassword(password);
const userId = generateId(15);
@@ -143,15 +140,21 @@ export async function signup(
if (diff < 2) {
// If the user was created less than 2 hours ago, we don't want to create a new user
return response<SignUpResponse>(res, {
data: {
emailVerificationRequired: true
},
success: true,
error: false,
message: `A user with that email address already exists. We sent an email to ${email} with a verification code.`,
status: HttpCode.OK
});
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A user with that email address already exists"
)
);
// return response<SignUpResponse>(res, {
// data: {
// emailVerificationRequired: true
// },
// success: true,
// error: false,
// message: `A user with that email address already exists. We sent an email to ${email} with a verification code.`,
// status: HttpCode.OK
// });
} else {
// If the user was created more than 2 hours ago, we want to delete the old user and create a new one
await db.delete(users).where(eq(users.userId, user.userId));

View File

@@ -4,7 +4,7 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
import HttpCode from "@server/types/HttpCode";
import { response } from "@server/lib";
import { db } from "@server/db";
import { db, userOrgs } from "@server/db";
import { User, emailVerificationCodes, users } from "@server/db";
import { eq } from "drizzle-orm";
import { isWithinExpirationDate } from "oslo";

View File

@@ -0,0 +1,252 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import {
roles,
Client,
clients,
roleClients,
userClients,
olms,
clientSites,
exitNodes,
orgs,
sites
} from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import moment from "moment";
import { hashPassword } from "@server/auth/password";
import { isValidCIDR, isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
import { OpenAPITags, registry } from "@server/openApi";
const createClientParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const createClientSchema = z
.object({
name: z.string().min(1).max(255),
siteIds: z.array(z.number().int().positive()),
olmId: z.string(),
secret: z.string(),
subnet: z.string(),
type: z.enum(["olm"])
})
.strict();
export type CreateClientBody = z.infer<typeof createClientSchema>;
export type CreateClientResponse = Client;
registry.registerPath({
method: "put",
path: "/org/{orgId}/client",
description: "Create a new client.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: createClientParamsSchema,
body: {
content: {
"application/json": {
schema: createClientSchema
}
}
}
},
responses: {}
});
export async function createClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createClientSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, type, siteIds, olmId, secret, subnet } = parsedBody.data;
const parsedParams = createClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
if (req.user && !req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
if (!isValidIP(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
if (!isIpInCidr(subnet, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
const updatedSubnet = `${subnet}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const subnetExistsClients = await db
.select()
.from(clients)
.where(eq(clients.subnet, updatedSubnet))
.limit(1);
if (subnetExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
const subnetExistsSites = await db
.select()
.from(sites)
.where(eq(sites.address, updatedSubnet))
.limit(1);
if (subnetExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
await db.transaction(async (trx) => {
// TODO: more intelligent way to pick the exit node
// make sure there is an exit node by counting the exit nodes table
const nodes = await db.select().from(exitNodes);
if (nodes.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No exit nodes available"
)
);
}
// get the first exit node
const exitNode = nodes[0];
const adminRole = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
trx.rollback();
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
const [newClient] = await trx
.insert(clients)
.values({
exitNodeId: exitNode.exitNodeId,
orgId,
name,
subnet: updatedSubnet,
type
})
.returning();
await trx.insert(roleClients).values({
roleId: adminRole[0].roleId,
clientId: newClient.clientId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
trx.insert(userClients).values({
userId: req.user?.userId!,
clientId: newClient.clientId
});
}
// Create site to client associations
if (siteIds && siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map((siteId) => ({
clientId: newClient.clientId,
siteId
}))
);
}
const secretHash = await hashPassword(secret);
await trx.insert(olms).values({
olmId,
secretHash,
clientId: newClient.clientId,
dateCreated: moment().toISOString()
});
return response<CreateClientResponse>(res, {
data: newClient,
success: true,
error: false,
message: "Site created successfully",
status: HttpCode.CREATED
});
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,88 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const deleteClientSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
registry.registerPath({
method: "delete",
path: "/client/{clientId}",
description: "Delete a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: deleteClientSchema
},
responses: {}
});
export async function deleteClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = deleteClientSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
await db.transaction(async (trx) => {
// Delete the client-site associations first
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
// Then delete the client itself
await trx
.delete(clients)
.where(eq(clients.clientId, clientId));
});
return response(res, {
data: null,
success: true,
error: false,
message: "Client deleted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,101 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { eq, and } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import stoi from "@server/lib/stoi";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const getClientSchema = z
.object({
clientId: z.string().transform(stoi).pipe(z.number().int().positive()),
orgId: z.string().optional()
})
.strict();
async function query(clientId: number) {
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return null;
}
// Get the siteIds associated with this client
const sites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
// Add the siteIds to the client object
return {
...client,
siteIds: sites.map(site => site.siteId)
};
}
export type GetClientResponse = NonNullable<Awaited<ReturnType<typeof query>>>;
registry.registerPath({
method: "get",
path: "/org/{orgId}/client/{clientId}",
description: "Get a client by its client ID.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
params: getClientSchema
},
responses: {}
});
export async function getClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = getClientSchema.safeParse(req.params);
if (!parsedParams.success) {
logger.error(
`Error parsing params: ${fromError(parsedParams.error).toString()}`
);
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const client = await query(clientId);
if (!client) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Client not found")
);
}
return response<GetClientResponse>(res, {
data: client,
success: true,
error: false,
message: "Client retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,6 @@
export * from "./pickClientDefaults";
export * from "./createClient";
export * from "./deleteClient";
export * from "./listClients";
export * from "./updateClient";
export * from "./getClient";

View File

@@ -0,0 +1,229 @@
import { db } from "@server/db";
import {
clients,
orgs,
roleClients,
sites,
userClients,
clientSites
} from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { and, count, eq, inArray, or, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const listClientsParamsSchema = z
.object({
orgId: z.string()
})
.strict();
const listClientsSchema = z.object({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.number().int().positive()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.number().int().nonnegative())
});
function queryClients(orgId: string, accessibleClientIds: number[]) {
return db
.select({
clientId: clients.clientId,
orgId: clients.orgId,
name: clients.name,
pubKey: clients.pubKey,
subnet: clients.subnet,
megabytesIn: clients.megabytesIn,
megabytesOut: clients.megabytesOut,
orgName: orgs.name,
type: clients.type,
online: clients.online
})
.from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
}
async function getSiteAssociations(clientIds: number[]) {
if (clientIds.length === 0) return [];
return db
.select({
clientId: clientSites.clientId,
siteId: clientSites.siteId,
siteName: sites.name,
siteNiceId: sites.niceId
})
.from(clientSites)
.leftJoin(sites, eq(clientSites.siteId, sites.siteId))
.where(inArray(clientSites.clientId, clientIds));
}
export type ListClientsResponse = {
clients: Array<Awaited<ReturnType<typeof queryClients>>[0] & { sites: Array<{
siteId: number;
siteName: string | null;
siteNiceId: string | null;
}> }>;
pagination: { total: number; limit: number; offset: number };
};
registry.registerPath({
method: "get",
path: "/org/{orgId}/clients",
description: "List all clients for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
query: listClientsSchema,
params: listClientsParamsSchema
},
responses: {}
});
export async function listClients(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedQuery = listClientsSchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error)
)
);
}
const { limit, offset } = parsedQuery.data;
const parsedParams = listClientsParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error)
)
);
}
const { orgId } = parsedParams.data;
if (req.user && orgId && orgId !== req.userOrgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
let accessibleClients;
if (req.user) {
accessibleClients = await db
.select({
clientId: sql<number>`COALESCE(${userClients.clientId}, ${roleClients.clientId})`
})
.from(userClients)
.fullJoin(
roleClients,
eq(userClients.clientId, roleClients.clientId)
)
.where(
or(
eq(userClients.userId, req.user!.userId),
eq(roleClients.roleId, req.userOrgRoleId!)
)
);
} else {
accessibleClients = await db
.select({ clientId: clients.clientId })
.from(clients)
.where(eq(clients.orgId, orgId));
}
const accessibleClientIds = accessibleClients.map(
(client) => client.clientId
);
const baseQuery = queryClients(orgId, accessibleClientIds);
// Get client count
const countQuery = db
.select({ count: count() })
.from(clients)
.where(
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
)
);
const clientsList = await baseQuery.limit(limit).offset(offset);
const totalCountResult = await countQuery;
const totalCount = totalCountResult[0].count;
// Get associated sites for all clients
const clientIds = clientsList.map(client => client.clientId);
const siteAssociations = await getSiteAssociations(clientIds);
// Group site associations by client ID
const sitesByClient = siteAssociations.reduce((acc, association) => {
if (!acc[association.clientId]) {
acc[association.clientId] = [];
}
acc[association.clientId].push({
siteId: association.siteId,
siteName: association.siteName,
siteNiceId: association.siteNiceId
});
return acc;
}, {} as Record<number, Array<{
siteId: number;
siteName: string | null;
siteNiceId: string | null;
}>>);
// Merge clients with their site associations
const clientsWithSites = clientsList.map(client => ({
...client,
sites: sitesByClient[client.clientId] || []
}));
return response<ListClientsResponse>(res, {
data: {
clients: clientsWithSites,
pagination: {
total: totalCount,
limit,
offset
}
},
success: true,
error: false,
message: "Clients retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,85 @@
import { Request, Response, NextFunction } from "express";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { generateId } from "@server/auth/sessions/app";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
export type PickClientDefaultsResponse = {
olmId: string;
olmSecret: string;
subnet: string;
};
const pickClientDefaultsSchema = z
.object({
orgId: z.string()
})
.strict();
registry.registerPath({
method: "get",
path: "/site/{siteId}/pick-client-defaults",
description: "Return pre-requisite data for creating a client.",
tags: [OpenAPITags.Client, OpenAPITags.Site],
request: {
params: pickClientDefaultsSchema
},
responses: {}
});
export async function pickClientDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = pickClientDefaultsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const olmId = generateId(15);
const secret = generateId(48);
const newSubnet = await getNextAvailableClientSubnet(orgId);
if (!newSubnet) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnet found"
)
);
}
const subnet = newSubnet.split("/")[0];
return response<PickClientDefaultsResponse>(res, {
data: {
olmId: olmId,
olmSecret: secret,
subnet: subnet
},
success: true,
error: false,
message: "Organization retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,225 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "../newt/peers";
import {
addPeer as olmAddPeer,
deletePeer as olmDeletePeer
} from "../olm/peers";
const updateClientParamsSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
const updateClientSchema = z
.object({
name: z.string().min(1).max(255).optional(),
siteIds: z
.array(z.string().transform(Number).pipe(z.number()))
.optional()
})
.strict();
export type UpdateClientBody = z.infer<typeof updateClientSchema>;
registry.registerPath({
method: "post",
path: "/client/{clientId}",
description: "Update a client by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: updateClientParamsSchema,
body: {
content: {
"application/json": {
schema: updateClientSchema
}
}
}
},
responses: {}
});
export async function updateClient(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = updateClientSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { name, siteIds } = parsedBody.data;
const parsedParams = updateClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { clientId } = parsedParams.data;
// Fetch the client to make sure it exists and the user has access to it
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Client with ID ${clientId} not found`
)
);
}
if (siteIds) {
let sitesAdded = [];
let sitesRemoved = [];
// Fetch existing site associations
const existingSites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
const existingSiteIds = existingSites.map((site) => site.siteId);
// Determine which sites were added and removed
sitesAdded = siteIds.filter(
(siteId) => !existingSiteIds.includes(siteId)
);
sitesRemoved = existingSiteIds.filter(
(siteId) => !siteIds.includes(siteId)
);
logger.info(
`Adding ${sitesAdded.length} new sites to client ${client.clientId}`
);
for (const siteId of sitesAdded) {
if (!client.subnet || !client.pubKey || !client.endpoint) {
logger.debug("Client subnet, pubKey or endpoint is not set");
continue;
}
const site = await newtAddPeer(siteId, {
publicKey: client.pubKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: client.endpoint
});
if (!site) {
logger.debug("Failed to add peer to newt - missing site");
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
await olmAddPeer(client.clientId, {
siteId: siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
logger.info(
`Removing ${sitesRemoved.length} sites from client ${client.clientId}`
);
for (const siteId of sitesRemoved) {
if (!client.pubKey) {
logger.debug("Client pubKey is not set");
continue;
}
const site = await newtDeletePeer(siteId, client.pubKey);
if (!site) {
logger.debug(
"Failed to delete peer from newt - missing site"
);
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
await olmDeletePeer(client.clientId, site.siteId, site.publicKey);
}
}
await db.transaction(async (trx) => {
// Update client name if provided
if (name) {
await trx
.update(clients)
.set({ name })
.where(eq(clients.clientId, clientId));
}
// Update site associations if provided
if (siteIds) {
// Delete existing site associations
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
// Create new site associations
if (siteIds.length > 0) {
await trx.insert(clientSites).values(
siteIds.map((siteId) => ({
clientId,
siteId
}))
);
}
}
// Fetch the updated client
const [updatedClient] = await trx
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
return response(res, {
data: updatedClient,
success: true,
error: false,
message: "Client updated successfully",
status: HttpCode.OK
});
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,287 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, Domain, domains, OrgDomains, orgDomains } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { subdomainSchema } from "@server/lib/schemas";
import { generateId } from "@server/auth/sessions/app";
import { eq, and } from "drizzle-orm";
import { isValidDomain } from "@server/lib/validators";
import { build } from "@server/build";
const paramsSchema = z
.object({
orgId: z.string()
})
.strict();
const bodySchema = z
.object({
type: z.enum(["ns", "cname", "wildcard"]),
baseDomain: subdomainSchema
})
.strict();
export type CreateDomainResponse = {
domainId: string;
nsRecords?: string[];
cnameRecords?: { baseDomain: string; value: string }[];
txtRecords?: { baseDomain: string; value: string }[];
};
// Helper to check if a domain is a subdomain or equal to another domain
function isSubdomainOrEqual(a: string, b: string): boolean {
const aParts = a.toLowerCase().split(".");
const bParts = b.toLowerCase().split(".");
if (aParts.length < bParts.length) return false;
return aParts.slice(-bParts.length).join(".") === bParts.join(".");
}
export async function createOrgDomain(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const { type, baseDomain } = parsedBody.data;
if (build == "oss") {
if (type !== "wildcard") {
return next(
createHttpError(
HttpCode.NOT_IMPLEMENTED,
"Creating NS or CNAME records is not supported"
)
);
}
} else if (build == "enterprise" || build == "saas") {
if (type !== "ns" && type !== "cname") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid domain type. Only NS, CNAME are allowed."
)
);
}
}
// Validate organization exists
if (!isValidDomain(baseDomain)) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid domain format")
);
}
let numOrgDomains: OrgDomains[] | undefined;
let cnameRecords: CreateDomainResponse["cnameRecords"];
let txtRecords: CreateDomainResponse["txtRecords"];
let nsRecords: CreateDomainResponse["nsRecords"];
let returned: Domain | undefined;
await db.transaction(async (trx) => {
const [existing] = await trx
.select()
.from(domains)
.where(
and(
eq(domains.baseDomain, baseDomain),
eq(domains.type, type)
)
)
.leftJoin(
orgDomains,
eq(orgDomains.domainId, domains.domainId)
);
if (existing) {
const {
domains: existingDomain,
orgDomains: existingOrgDomain
} = existing;
// user alrady added domain to this account
// always reject
if (existingOrgDomain?.orgId === orgId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Domain is already added to this org"
)
);
}
// domain already exists elsewhere
// check if it's already fully verified
if (existingDomain.verified) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Domain is already verified to an org"
)
);
}
}
// --- Domain overlap logic ---
// Only consider existing verified domains
const verifiedDomains = await trx
.select()
.from(domains)
.where(eq(domains.verified, true));
if (type == "cname") {
// Block if a verified CNAME exists at the same name
const cnameExists = verifiedDomains.some(
(d) => d.type === "cname" && d.baseDomain === baseDomain
);
if (cnameExists) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`A CNAME record already exists for ${baseDomain}. Only one CNAME record is allowed per domain.`
)
);
}
// Block if a verified NS exists at or below (same or subdomain)
const nsAtOrBelow = verifiedDomains.some(
(d) =>
d.type === "ns" &&
(isSubdomainOrEqual(baseDomain, d.baseDomain) ||
baseDomain === d.baseDomain)
);
if (nsAtOrBelow) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`A nameserver (NS) record exists at or below ${baseDomain}. You cannot create a CNAME record here.`
)
);
}
} else if (type == "ns") {
// Block if a verified NS exists at or below (same or subdomain)
const nsAtOrBelow = verifiedDomains.some(
(d) =>
d.type === "ns" &&
(isSubdomainOrEqual(baseDomain, d.baseDomain) ||
baseDomain === d.baseDomain)
);
if (nsAtOrBelow) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`A nameserver (NS) record already exists at or below ${baseDomain}. You cannot create another NS record here.`
)
);
}
} else if (type == "wildcard") {
// TODO: Figure out how to handle wildcards
}
const domainId = generateId(15);
const [insertedDomain] = await trx
.insert(domains)
.values({
domainId,
baseDomain,
type,
verified: build == "oss" ? true : false
})
.returning();
returned = insertedDomain;
// add domain to account
await trx
.insert(orgDomains)
.values({
orgId,
domainId
})
.returning();
// TODO: This needs to be cross region and not hardcoded
if (type === "ns") {
nsRecords = ["ns-east.fossorial.io", "ns-west.fossorial.io"];
} else if (type === "cname") {
cnameRecords = [
{
value: `${domainId}.cname.fossorial.io`,
baseDomain: baseDomain
},
{
value: `_acme-challenge.${domainId}.cname.fossorial.io`,
baseDomain: `_acme-challenge.${baseDomain}`
}
];
} else if (type === "wildcard") {
cnameRecords = [
{
value: `Server IP Address`,
baseDomain: `*.${baseDomain}`
},
{
value: `Server IP Address`,
baseDomain: `${baseDomain}`
}
];
}
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
});
if (!returned) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create domain"
)
);
}
return response<CreateDomainResponse>(res, {
data: {
domainId: returned.domainId,
cnameRecords,
txtRecords,
nsRecords
},
success: true,
error: false,
message: "Domain created successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,72 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, domains, OrgDomains, orgDomains } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { and, eq } from "drizzle-orm";
const paramsSchema = z
.object({
domainId: z.string(),
orgId: z.string()
})
.strict();
export type DeleteAccountDomainResponse = {
success: boolean;
};
export async function deleteAccountDomain(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsed = paramsSchema.safeParse(req.params);
if (!parsed.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsed.error).toString()
)
);
}
const { domainId, orgId } = parsed.data;
let numOrgDomains: OrgDomains[] | undefined;
await db.transaction(async (trx) => {
await trx
.delete(orgDomains)
.where(
and(
eq(orgDomains.orgId, orgId),
eq(orgDomains.domainId, domainId)
)
);
await trx.delete(domains).where(eq(domains.domainId, domainId));
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
});
return response<DeleteAccountDomainResponse>(res, {
data: { success: true },
success: true,
error: false,
message: "Domain deleted from account successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -1 +1,4 @@
export * from "./listDomains";
export * from "./createOrgDomain";
export * from "./deleteOrgDomain";
export * from "./restartOrgDomain";

View File

@@ -37,7 +37,11 @@ async function queryDomains(orgId: string, limit: number, offset: number) {
const res = await db
.select({
domainId: domains.domainId,
baseDomain: domains.baseDomain
baseDomain: domains.baseDomain,
verified: domains.verified,
type: domains.type,
failed: domains.failed,
tries: domains.tries,
})
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId))
@@ -112,7 +116,7 @@ export async function listDomains(
},
success: true,
error: false,
message: "Users retrieved successfully",
message: "Domains retrieved successfully",
status: HttpCode.OK
});
} catch (error) {

View File

@@ -0,0 +1,57 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, domains } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { and, eq } from "drizzle-orm";
const paramsSchema = z
.object({
domainId: z.string(),
orgId: z.string()
})
.strict();
export type RestartOrgDomainResponse = {
success: boolean;
};
export async function restartOrgDomain(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsed = paramsSchema.safeParse(req.params);
if (!parsed.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsed.error).toString()
)
);
}
const { domainId, orgId } = parsed.data;
await db
.update(domains)
.set({ failed: false, tries: 0 })
.where(and(eq(domains.domainId, domainId)));
return response<RestartOrgDomainResponse>(res, {
data: { success: true },
success: true,
error: false,
message: "Domain restarted successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -8,6 +8,7 @@ import * as target from "./target";
import * as user from "./user";
import * as auth from "./auth";
import * as role from "./role";
import * as client from "./client";
import * as supporterKey from "./supporterKey";
import * as accessToken from "./accessToken";
import * as idp from "./idp";
@@ -28,14 +29,20 @@ import {
getUserOrgs,
verifyUserIsServerAdmin,
verifyIsLoggedInUser,
verifyApiKeyAccess
verifyClientAccess,
verifyApiKeyAccess,
verifyDomainAccess,
verifyClientsEnabled,
verifyUserHasAction,
verifyUserIsOrgOwner
} from "@server/middlewares";
import { verifyUserHasAction } from "../middlewares/verifyUserHasAction";
import { createStore } from "@server/lib/rateLimitStore";
import { ActionsEnum } from "@server/auth/actions";
import { verifyUserIsOrgOwner } from "../middlewares/verifyUserIsOrgOwner";
import { createNewt, getToken } from "./newt";
import { createNewt, getNewtToken } from "./newt";
import { getOlmToken } from "./olm";
import rateLimit from "express-rate-limit";
import createHttpError from "http-errors";
import { build } from "@server/build";
// Root routes
export const unauthenticated = Router();
@@ -48,8 +55,11 @@ unauthenticated.get("/", (_, res) => {
export const authenticated = Router();
authenticated.use(verifySessionUserMiddleware);
authenticated.get("/pick-org-defaults", org.pickOrgDefaults);
authenticated.get("/org/checkId", org.checkId);
authenticated.put("/org", getUserOrgs, org.createOrg);
if (build === "oss" || build === "enterprise") {
authenticated.put("/org", getUserOrgs, org.createOrg);
}
authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs);
authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs);
@@ -104,6 +114,55 @@ authenticated.get(
verifyUserHasAction(ActionsEnum.getSite),
site.getSite
);
authenticated.get(
"/org/:orgId/pick-client-defaults",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.pickClientDefaults
);
authenticated.get(
"/org/:orgId/clients",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listClients),
client.listClients
);
authenticated.get(
"/org/:orgId/client/:clientId",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.getClient),
client.getClient
);
authenticated.put(
"/org/:orgId/client",
verifyClientsEnabled,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createClient),
client.createClient
);
authenticated.delete(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.deleteClient),
client.deleteClient
);
authenticated.post(
"/client/:clientId",
verifyClientsEnabled,
verifyClientAccess, // this will check if the user has access to the client
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
client.updateClient
);
// authenticated.get(
// "/site/:siteId/roles",
// verifySiteAccess,
@@ -698,6 +757,29 @@ authenticated.get(
apiKeys.getApiKey
);
authenticated.put(
`/org/:orgId/domain`,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createOrgDomain),
domain.createOrgDomain
);
authenticated.post(
`/org/:orgId/domain/:domainId/restart`,
verifyOrgAccess,
verifyDomainAccess,
verifyUserHasAction(ActionsEnum.restartOrgDomain),
domain.restartOrgDomain
);
authenticated.delete(
`/org/:orgId/domain/:domainId`,
verifyOrgAccess,
verifyDomainAccess,
verifyUserHasAction(ActionsEnum.deleteOrgDomain),
domain.deleteAccountDomain
);
// Auth routes
export const authRouter = Router();
unauthenticated.use("/auth", authRouter);
@@ -751,7 +833,20 @@ authRouter.post(
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
}),
getToken
getNewtToken
);
authRouter.post(
"/olm/get-token",
rateLimit({
windowMs: 15 * 60 * 1000,
max: 900,
keyGenerator: (req) => `newtGetToken:${req.body.newtId}`,
handler: (req, res, next) => {
const message = `You can only request an Olm token ${900} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
}),
getOlmToken
);
authRouter.post(
@@ -836,7 +931,8 @@ authRouter.post(
handler: (req, res, next) => {
const message = `You can only request an email verification code ${15} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
},
store: createStore()
}),
auth.requestEmailVerificationCode
);
@@ -856,7 +952,8 @@ authRouter.post(
handler: (req, res, next) => {
const message = `You can only request a password reset ${15} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
},
store: createStore()
}),
auth.requestPasswordReset
);
@@ -914,7 +1011,8 @@ authRouter.post(
handler: (req, res, next) => {
const message = `You can only request an email OTP ${15} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
}
},
store: createStore()
}),
resource.authWithWhitelist
);

View File

@@ -0,0 +1,160 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, exitNodes, newts, olms, Site, sites, clientSites } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
// Define Zod schema for request validation
const getAllRelaysSchema = z.object({
publicKey: z.string().optional(),
});
// Type for peer destination
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
// Updated mappings type to support multiple destinations per endpoint
interface ProxyMapping {
destinations: PeerDestination[];
}
export async function getAllRelays(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// Validate request parameters
const parsedParams = getAllRelaysSchema.safeParse(req.body);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { publicKey } = parsedParams.data;
if (!publicKey) {
return next(createHttpError(HttpCode.BAD_REQUEST, 'publicKey is required'));
}
// Fetch exit node
let [exitNode] = await db.select().from(exitNodes).where(eq(exitNodes.publicKey, publicKey));
if (!exitNode) {
return next(createHttpError(HttpCode.NOT_FOUND, "Exit node not found"));
}
// Fetch sites for this exit node
const sitesRes = await db.select().from(sites).where(eq(sites.exitNodeId, exitNode.exitNodeId));
if (sitesRes.length === 0) {
return res.status(HttpCode.OK).send({
mappings: {}
});
}
// Initialize mappings object for multi-peer support
let mappings: { [key: string]: ProxyMapping } = {};
// Process each site
for (const site of sitesRes) {
if (!site.endpoint || !site.subnet || !site.listenPort) {
continue;
}
// Find all clients associated with this site through clientSites
const clientSitesRes = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, site.siteId));
for (const clientSite of clientSitesRes) {
// Get client information
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientSite.clientId));
if (!client || !client.endpoint) {
continue;
}
// Add this site as a destination for the client
if (!mappings[client.endpoint]) {
mappings[client.endpoint] = { destinations: [] };
}
// Add site as a destination for this client
const destination: PeerDestination = {
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
};
// Check if this destination is already in the array to avoid duplicates
const isDuplicate = mappings[client.endpoint].destinations.some(
dest => dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[client.endpoint].destinations.push(destination);
}
}
// Also handle site-to-site communication (all sites in the same org)
if (site.orgId) {
const orgSites = await db
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
for (const peer of orgSites) {
// Skip self
if (peer.siteId === site.siteId || !peer.endpoint || !peer.subnet || !peer.listenPort) {
continue;
}
// Add peer site as a destination for this site
if (!mappings[site.endpoint]) {
mappings[site.endpoint] = { destinations: [] };
}
const destination: PeerDestination = {
destinationIP: peer.subnet.split("/")[0],
destinationPort: peer.listenPort
};
// Check for duplicates
const isDuplicate = mappings[site.endpoint].destinations.some(
dest => dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[site.endpoint].destinations.push(destination);
}
}
}
}
logger.debug(`Returning mappings for ${Object.keys(mappings).length} endpoints`);
return res.status(HttpCode.OK).send({ mappings });
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
}
}

View File

@@ -53,7 +53,7 @@ export async function getConfig(
}
// Fetch exit node
let exitNodeQuery = await db
const exitNodeQuery = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.publicKey, publicKey));
@@ -68,6 +68,10 @@ export async function getConfig(
subEndpoint = await getUniqueExitNodeEndpointName();
}
const exitNodeName =
config.getRawConfig().gerbil.exit_node_name ||
`Exit Node ${publicKey.slice(0, 8)}`;
// create a new exit node
exitNode = await db
.insert(exitNodes)
@@ -77,7 +81,7 @@ export async function getConfig(
address,
listenPort,
reachableAt,
name: `Exit Node ${publicKey.slice(0, 8)}`
name: exitNodeName
})
.returning()
.execute();

View File

@@ -1,2 +1,4 @@
export * from "./getConfig";
export * from "./receiveBandwidth";
export * from "./updateHolePunch";
export * from "./getAllRelays";

View File

@@ -1,12 +1,15 @@
import { Request, Response, NextFunction } from "express";
import { eq } from "drizzle-orm";
import { sites, } from "@server/db";
import { eq, and, lt, inArray, sql } from "drizzle-orm";
import { sites } from "@server/db";
import { db } from "@server/db";
import logger from "@server/logger";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
// Track sites that are already offline to avoid unnecessary queries
const offlineSites = new Set<string>();
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
@@ -25,47 +28,101 @@ export const receiveBandwidth = async (
throw new Error("Invalid bandwidth data");
}
const currentTime = new Date();
const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago
logger.debug(`Received data: ${JSON.stringify(bandwidthData)}`);
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// First, handle sites that are actively reporting bandwidth
const activePeers = bandwidthData.filter(peer => peer.bytesIn > 0); // Bytesout will have data as it tries to send keep alive messages
const [site] = await trx
.select()
.from(sites)
.where(eq(sites.pubKey, publicKey))
.limit(1);
if (activePeers.length > 0) {
// Remove any active peers from offline tracking since they're sending data
activePeers.forEach(peer => offlineSites.delete(peer.publicKey));
if (!site) {
logger.warn(`Site not found for public key: ${publicKey}`);
continue;
}
let online = site.online;
// Aggregate usage data by organization
const orgUsageMap = new Map<string, number>();
const orgUptimeMap = new Map<string, number>();
// if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline
if (bytesIn > 0 || bytesOut > 0) {
online = true;
} else if (site.lastBandwidthUpdate) {
const lastBandwidthUpdate = new Date(
site.lastBandwidthUpdate
);
const currentTime = new Date();
const diff =
currentTime.getTime() - lastBandwidthUpdate.getTime();
if (diff < 300000) {
online = false;
// Update all active sites with bandwidth data and get the site data in one operation
const updatedSites = [];
for (const peer of activePeers) {
const updatedSite = await trx
.update(sites)
.set({
megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
lastBandwidthUpdate: currentTime.toISOString(),
online: true
})
.where(eq(sites.pubKey, peer.publicKey))
.returning({
online: sites.online,
orgId: sites.orgId,
siteId: sites.siteId,
lastBandwidthUpdate: sites.lastBandwidthUpdate,
});
if (updatedSite.length > 0) {
updatedSites.push({ ...updatedSite[0], peer });
}
}
// Update the site's bandwidth usage
await trx
.update(sites)
.set({
megabytesOut: (site.megabytesOut || 0) + bytesIn,
megabytesIn: (site.megabytesIn || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
online
})
.where(eq(sites.siteId, site.siteId));
// Calculate org usage aggregations using the updated site data
for (const { peer, ...site } of updatedSites) {
// Aggregate bandwidth usage for the org
const totalBandwidth = peer.bytesIn + peer.bytesOut;
const currentOrgUsage = orgUsageMap.get(site.orgId) || 0;
orgUsageMap.set(site.orgId, currentOrgUsage + totalBandwidth);
// Add 10 seconds of uptime for each active site
const currentOrgUptime = orgUptimeMap.get(site.orgId) || 0;
orgUptimeMap.set(site.orgId, currentOrgUptime + 10 / 60); // Store in minutes and jut add 10 seconds
}
}
// Handle sites that reported zero bandwidth but need online status updated
const zeroBandwidthPeers = bandwidthData.filter(peer =>
peer.bytesIn === 0 && !offlineSites.has(peer.publicKey) // Bytesout will have data as it tries to send keep alive messages
);
if (zeroBandwidthPeers.length > 0) {
const zeroBandwidthSites = await trx
.select()
.from(sites)
.where(inArray(sites.pubKey, zeroBandwidthPeers.map(p => p.publicKey)));
for (const site of zeroBandwidthSites) {
let newOnlineStatus = site.online;
// Check if site should go offline based on last bandwidth update WITH DATA
if (site.lastBandwidthUpdate) {
const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
if (lastUpdateWithData < oneMinuteAgo) {
newOnlineStatus = false;
}
} else {
// No previous data update recorded, set to offline
newOnlineStatus = false;
}
// Always update lastBandwidthUpdate to show this instance is receiving reports
// Only update online status if it changed
if (site.online !== newOnlineStatus) {
await trx
.update(sites)
.set({
online: newOnlineStatus
})
.where(eq(sites.siteId, site.siteId));
// If site went offline, add it to our tracking set
if (!newOnlineStatus && site.pubKey) {
offlineSites.add(site.pubKey);
}
}
}
}
});
@@ -73,7 +130,7 @@ export const receiveBandwidth = async (
data: {},
success: true,
error: false,
message: "Organization retrieved successfully",
message: "Bandwidth data updated successfully",
status: HttpCode.OK
});
} catch (error) {

View File

@@ -0,0 +1,242 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, newts, olms, Site, sites, clientSites } from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
// Define Zod schema for request validation
const updateHolePunchSchema = z.object({
olmId: z.string().optional(),
newtId: z.string().optional(),
token: z.string(),
ip: z.string(),
port: z.number(),
timestamp: z.number()
});
// New response type with multi-peer destination support
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
export async function updateHolePunch(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// Validate request parameters
const parsedParams = updateHolePunchSchema.safeParse(req.body);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { olmId, newtId, ip, port, timestamp, token } = parsedParams.data;
let currentSiteId: number | undefined;
let destinations: PeerDestination[] = [];
if (olmId) {
logger.debug(`Got hole punch with ip: ${ip}, port: ${port} for olmId: ${olmId}`);
const { session, olm: olmSession } =
await validateOlmSessionToken(token);
if (!session || !olmSession) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
if (olmId !== olmSession.olmId) {
logger.warn(`Olm ID mismatch: ${olmId} !== ${olmSession.olmId}`);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!olm || !olm.clientId) {
logger.warn(`Olm not found: ${olmId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Olm not found")
);
}
const [client] = await db
.update(clients)
.set({
endpoint: `${ip}:${port}`,
lastHolePunch: timestamp
})
.where(eq(clients.clientId, olm.clientId))
.returning();
if (!client) {
logger.warn(`Client not found for olm: ${olmId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Client not found")
);
}
// Get all sites that this client is connected to
const clientSitePairs = await db
.select()
.from(clientSites)
.where(eq(clientSites.clientId, client.clientId));
if (clientSitePairs.length === 0) {
logger.warn(`No sites found for client: ${client.clientId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "No sites found for client")
);
}
// Get all sites details
const siteIds = clientSitePairs.map(pair => pair.siteId);
for (const siteId of siteIds) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (site && site.subnet && site.listenPort) {
destinations.push({
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
});
}
}
} else if (newtId) {
const { session, newt: newtSession } =
await validateNewtSessionToken(token);
if (!session || !newtSession) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
if (newtId !== newtSession.newtId) {
logger.warn(`Newt ID mismatch: ${newtId} !== ${newtSession.newtId}`);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Unauthorized")
);
}
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!newt || !newt.siteId) {
logger.warn(`Newt not found: ${newtId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "New not found")
);
}
currentSiteId = newt.siteId;
// Update the current site with the new endpoint
const [updatedSite] = await db
.update(sites)
.set({
endpoint: `${ip}:${port}`,
lastHolePunch: timestamp
})
.where(eq(sites.siteId, newt.siteId))
.returning();
if (!updatedSite || !updatedSite.subnet) {
logger.warn(`Site not found: ${newt.siteId}`);
return next(
createHttpError(HttpCode.NOT_FOUND, "Site not found")
);
}
// Find all clients that connect to this site
const sitesClientPairs = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, newt.siteId));
// Get client details for each client
for (const pair of sitesClientPairs) {
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, pair.clientId));
if (client && client.endpoint) {
const [host, portStr] = client.endpoint.split(':');
if (host && portStr) {
destinations.push({
destinationIP: host,
destinationPort: parseInt(portStr, 10)
});
}
}
}
// If this is a newt/site, also add other sites in the same org
// if (updatedSite.orgId) {
// const orgSites = await db
// .select()
// .from(sites)
// .where(eq(sites.orgId, updatedSite.orgId));
// for (const site of orgSites) {
// // Don't add the current site to the destinations
// if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) {
// const [host, portStr] = site.endpoint.split(':');
// if (host && portStr) {
// destinations.push({
// destinationIP: host,
// destinationPort: site.listenPort
// });
// }
// }
// }
// }
}
// if (destinations.length === 0) {
// logger.warn(
// `No peer destinations found for olmId: ${olmId} or newtId: ${newtId}`
// );
// return next(createHttpError(HttpCode.NOT_FOUND, "No peer destinations found"));
// }
// Return the new multi-peer structure
return res.status(HttpCode.OK).send({
destinations: destinations
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred..."
)
);
}
}

View File

@@ -11,6 +11,7 @@ import {
idpOidcConfig,
idpOrg,
orgs,
Role,
roles,
userOrgs,
users
@@ -307,6 +308,8 @@ export async function validateOidcCallback(
let existingUserId = existingUser?.userId;
let orgUserCounts: { orgId: string; userCount: number }[] = [];
// sync the user with the orgs and roles
await db.transaction(async (trx) => {
let userId = existingUser?.userId;
@@ -410,6 +413,19 @@ export async function validateOidcCallback(
}))
);
}
// Loop through all the orgs and get the total number of users from the userOrgs table
for (const org of currentUserOrgs) {
const userCount = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, org.orgId));
orgUserCounts.push({
orgId: org.orgId,
userCount: userCount.length
});
}
});
const token = generateSessionToken();

View File

@@ -51,6 +51,8 @@ internalRouter.use("/gerbil", gerbilRouter);
gerbilRouter.post("/get-config", gerbil.getConfig);
gerbilRouter.post("/receive-bandwidth", gerbil.receiveBandwidth);
gerbilRouter.post("/update-hole-punch", gerbil.updateHolePunch);
gerbilRouter.post("/get-all-relays", gerbil.getAllRelays);
// Badger routes
const badgerRouter = Router();

View File

@@ -1,12 +1,29 @@
import {
handleRegisterMessage,
handleNewtRegisterMessage,
handleReceiveBandwidthMessage,
handleGetConfigMessage,
handleDockerStatusMessage,
handleDockerContainersMessage
handleDockerContainersMessage,
handleNewtPingRequestMessage
} from "./newt";
import {
handleOlmRegisterMessage,
handleOlmRelayMessage,
handleOlmPingMessage,
startOfflineChecker
} from "./olm";
import { MessageHandler } from "./ws";
export const messageHandlers: Record<string, MessageHandler> = {
"newt/wg/register": handleRegisterMessage,
"newt/wg/register": handleNewtRegisterMessage,
"olm/wg/register": handleOlmRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"olm/wg/relay": handleOlmRelayMessage,
"olm/ping": handleOlmPingMessage,
"newt/socket/status": handleDockerStatusMessage,
"newt/socket/containers": handleDockerContainersMessage
"newt/socket/containers": handleDockerContainersMessage,
"newt/ping/request": handleNewtPingRequestMessage,
};
startOfflineChecker(); // this is to handle the offline check for olms

View File

@@ -24,7 +24,7 @@ export const newtGetTokenBodySchema = z.object({
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
export async function getToken(
export async function getNewtToken(
req: Request,
res: Response,
next: NextFunction

View File

@@ -0,0 +1,165 @@
import { z } from "zod";
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { db } from "@server/db";
import { clients, clientSites, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
const inputSchema = z.object({
publicKey: z.string(),
port: z.number().int().positive()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
const now = new Date().getTime() / 1000;
logger.debug("Handling Newt get config message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, port } = message.data as Input;
const siteId = newt.siteId;
// Get the current site data
const [existingSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!existingSite) {
logger.warn("handleGetConfigMessage: Site not found");
return;
}
// we need to wait for hole punch success
if (!existingSite.endpoint) {
logger.warn(`Site ${existingSite.siteId} has no endpoint, skipping`);
return;
}
if (existingSite.publicKey !== publicKey) {
// TODO: somehow we should make sure a recent hole punch has happened if this occurs (hole punch could be from the last restart if done quickly)
}
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) {
logger.warn(
`Site ${existingSite.siteId} last hole punch is too old, skipping`
);
return;
}
// update the endpoint and the public key
const [site] = await db
.update(sites)
.set({
publicKey,
listenPort: port
})
.where(eq(sites.siteId, siteId))
.returning();
if (!site) {
logger.error("handleGetConfigMessage: Failed to update site");
return;
}
// Get all clients connected to this site
const clientsRes = await db
.select()
.from(clients)
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
// Prepare peers data for the response
const peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
return false;
}
if (!client.clients.subnet) {
return false;
}
if (!client.clients.endpoint) {
return false;
}
if (!client.clients.online) {
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
try {
if (site.endpoint && site.publicKey) {
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
} catch (error) {
logger.error(
`Failed to add/update peer ${client.clients.pubKey} to newt ${newt.newtId}: ${error}`
);
}
return {
publicKey: client.clients.pubKey!,
allowedIps: [`${client.clients.subnet.split('/')[0]}/32`], // we want to only allow from that client
endpoint: client.clientSites.isRelayed
? ""
: client.clients.endpoint! // if its relayed it should be localhost
};
})
);
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Build the configuration response
const configResponse = {
ipAddress: site.address,
peers: validPeers
};
logger.debug("Sending config: ", configResponse);
return {
message: {
type: "newt/wg/receive-config",
data: {
...configResponse
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -0,0 +1,89 @@
import { db, sites } from "@server/db";
import { MessageHandler } from "../ws";
import { exitNodes, Newt } from "@server/db";
import logger from "@server/logger";
import config from "@server/lib/config";
import { ne, eq, or, and, count } from "drizzle-orm";
export const handleNewtPingRequestMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling ping request newt message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
// TODO: pick which nodes to send and ping better than just all of them
let exitNodesList = await db
.select()
.from(exitNodes);
exitNodesList = exitNodesList.filter((node) => node.maxConnections !== 0);
let lastExitNodeId = null;
if (newt.siteId) {
const [lastExitNode] = await db
.select()
.from(sites)
.where(eq(sites.siteId, newt.siteId))
.limit(1);
lastExitNodeId = lastExitNode?.exitNodeId || null;
}
const exitNodesPayload = await Promise.all(
exitNodesList.map(async (node) => {
// (MAX_CONNECTIONS - current_connections) / MAX_CONNECTIONS)
// higher = more desirable
// like saying, this node has x% of its capacity left
let weight = 1;
const maxConnections = node.maxConnections;
if (maxConnections !== null && maxConnections !== undefined) {
const [currentConnections] = await db
.select({
count: count()
})
.from(sites)
.where(
and(
eq(sites.exitNodeId, node.exitNodeId),
eq(sites.online, true)
)
);
if (currentConnections.count >= maxConnections) {
return null
}
weight =
(maxConnections - currentConnections.count) /
maxConnections;
}
return {
exitNodeId: node.exitNodeId,
exitNodeName: node.name,
endpoint: node.endpoint,
weight,
wasPreviouslyConnected: node.exitNodeId === lastExitNodeId
};
})
);
// filter out null values
const filteredExitNodes = exitNodesPayload.filter((node) => node !== null);
return {
message: {
type: "newt/ping/exitNodes",
data: {
exitNodes: filteredExitNodes
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};

View File

@@ -0,0 +1,358 @@
import { db, newts } from "@server/db";
import { MessageHandler } from "../ws";
import { exitNodes, Newt, resources, sites, Target, targets } from "@server/db";
import { eq, and, sql, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
import config from "@server/lib/config";
import {
findNextAvailableCidr,
getNextAvailableClientSubnet
} from "@server/lib/ip";
export type ExitNodePingResult = {
exitNodeId: number;
latencyMs: number;
weight: number;
error?: string;
exitNodeName: string;
endpoint: string;
wasPreviouslyConnected: boolean;
};
export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling register newt message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const siteId = newt.siteId;
const { publicKey, pingResults, newtVersion, backwardsCompatible } =
message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
if (backwardsCompatible) {
logger.debug(
"Backwards compatible mode detecting - not sending connect message and waiting for ping response."
);
return;
}
let exitNodeId: number | undefined;
if (pingResults) {
const bestPingResult = selectBestExitNode(
pingResults as ExitNodePingResult[]
);
if (!bestPingResult) {
logger.warn("No suitable exit node found based on ping results");
return;
}
exitNodeId = bestPingResult.exitNodeId;
}
if (newtVersion) {
// update the newt version in the database
await db
.update(newts)
.set({
version: newtVersion as string
})
.where(eq(newts.newtId, newt.newtId));
}
const [oldSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!oldSite || !oldSite.exitNodeId) {
logger.warn("Site not found or does not have exit node");
return;
}
let siteSubnet = oldSite.subnet;
let exitNodeIdToQuery = oldSite.exitNodeId;
if (exitNodeId && (oldSite.exitNodeId !== exitNodeId || !oldSite.subnet)) {
// This effectively moves the exit node to the new one
exitNodeIdToQuery = exitNodeId; // Use the provided exitNodeId if it differs from the site's exitNodeId
const sitesQuery = await db
.select({
subnet: sites.subnet
})
.from(sites)
.where(eq(sites.exitNodeId, exitNodeId));
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, exitNodeIdToQuery))
.limit(1);
const blockSize = config.getRawConfig().gerbil.site_block_size;
const subnets = sitesQuery.map((site) => site.subnet).filter((subnet) => subnet !== null);
subnets.push(exitNode.address.replace(/\/\d+$/, `/${blockSize}`));
const newSubnet = findNextAvailableCidr(
subnets,
blockSize,
exitNode.address
);
if (!newSubnet) {
logger.error("No available subnets found for the new exit node");
return;
}
siteSubnet = newSubnet;
await db
.update(sites)
.set({
pubKey: publicKey,
exitNodeId: exitNodeId,
subnet: newSubnet
})
.where(eq(sites.siteId, siteId))
.returning();
} else {
await db
.update(sites)
.set({
pubKey: publicKey
})
.where(eq(sites.siteId, siteId))
.returning();
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, exitNodeIdToQuery))
.limit(1);
if (oldSite.pubKey && oldSite.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(oldSite.exitNodeId, oldSite.pubKey);
}
if (!siteSubnet) {
logger.warn("Site has no subnet");
return;
}
// add the peer to the exit node
await addPeer(exitNodeIdToQuery, {
publicKey: publicKey,
allowedIps: [siteSubnet]
});
// Improved version
const allResources = await db.transaction(async (tx) => {
// First get all resources for the site
const resourcesList = await tx
.select({
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol
})
.from(resources)
.where(eq(resources.siteId, siteId));
// Get all enabled targets for these resources in a single query
const resourceIds = resourcesList.map((r) => r.resourceId);
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
// Combine the data in JS instead of using SQL for the JSON
return resourcesList.map((resource) => ({
...resource,
targets: allTargets.filter(
(target) => target.resourceId === resource.resourceId
)
}));
});
const { tcpTargets, udpTargets } = allResources.reduce(
(acc, resource) => {
// Skip resources with no targets
if (!resource.targets?.length) return acc;
// Format valid targets into strings
const formattedTargets = resource.targets
.filter(
(target: Target) =>
target?.internalPort && target?.ip && target?.port
)
.map(
(target: Target) =>
`${target.internalPort}:${target.ip}:${target.port}`
);
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(...formattedTargets);
} else {
acc.udpTargets.push(...formattedTargets);
}
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
return {
message: {
type: "newt/wg/connect",
data: {
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
publicKey: exitNode.publicKey,
serverIP: exitNode.address.split("/")[0],
tunnelIP: siteSubnet.split("/")[0],
targets: {
udp: udpTargets,
tcp: tcpTargets
}
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};
/**
* Selects the most suitable exit node from a list of ping results.
*
* The selection algorithm follows these steps:
*
* 1. **Filter Invalid Nodes**: Excludes nodes with errors or zero weight.
*
* 2. **Sort by Latency**: Sorts valid nodes in ascending order of latency.
*
* 3. **Preferred Selection**:
* - If the lowest-latency node has sufficient capacity (≥10% weight),
* check if a previously connected node is also acceptable.
* - The previously connected node is preferred if its latency is within
* 30ms or 15% of the best nodes latency.
*
* 4. **Fallback to Next Best**:
* - If the lowest-latency node is under capacity, find the next node
* with acceptable capacity.
*
* 5. **Final Fallback**:
* - If no nodes meet the capacity threshold, fall back to the node
* with the highest weight (i.e., most available capacity).
*
*/
function selectBestExitNode(
pingResults: ExitNodePingResult[]
): ExitNodePingResult | null {
const MIN_CAPACITY_THRESHOLD = 0.1;
const LATENCY_TOLERANCE_MS = 30;
const LATENCY_TOLERANCE_PERCENT = 0.15;
// Filter out invalid nodes
const validNodes = pingResults.filter((n) => !n.error && n.weight > 0);
if (validNodes.length === 0) {
logger.error("No valid exit nodes available");
return null;
}
// Sort by latency (ascending)
const sortedNodes = validNodes
.slice()
.sort((a, b) => a.latencyMs - b.latencyMs);
const lowestLatencyNode = sortedNodes[0];
logger.info(
`Lowest latency node: ${lowestLatencyNode.exitNodeName} (${lowestLatencyNode.latencyMs} ms, weight=${lowestLatencyNode.weight.toFixed(2)})`
);
// If lowest latency node has enough capacity, check if previously connected node is acceptable
if (lowestLatencyNode.weight >= MIN_CAPACITY_THRESHOLD) {
const previouslyConnectedNode = sortedNodes.find(
(n) =>
n.wasPreviouslyConnected && n.weight >= MIN_CAPACITY_THRESHOLD
);
if (previouslyConnectedNode) {
const latencyDiff =
previouslyConnectedNode.latencyMs - lowestLatencyNode.latencyMs;
const percentDiff = latencyDiff / lowestLatencyNode.latencyMs;
if (
latencyDiff <= LATENCY_TOLERANCE_MS ||
percentDiff <= LATENCY_TOLERANCE_PERCENT
) {
logger.info(
`Sticking with previously connected node: ${previouslyConnectedNode.exitNodeName} ` +
`(${previouslyConnectedNode.latencyMs} ms), latency diff = ${latencyDiff.toFixed(1)}ms ` +
`/ ${(percentDiff * 100).toFixed(1)}%.`
);
return previouslyConnectedNode;
}
}
return lowestLatencyNode;
}
// Otherwise, find the next node (after the lowest) that has enough capacity
for (let i = 1; i < sortedNodes.length; i++) {
const node = sortedNodes[i];
if (node.weight >= MIN_CAPACITY_THRESHOLD) {
logger.info(
`Lowest latency node under capacity. Using next best: ${node.exitNodeName} ` +
`(${node.latencyMs} ms, weight=${node.weight.toFixed(2)})`
);
return node;
}
}
// Fallback: pick the highest weight node
const fallbackNode = validNodes.reduce((a, b) =>
a.weight > b.weight ? a : b
);
logger.warn(
`No nodes with ≥10% weight. Falling back to highest capacity node: ${fallbackNode.exitNodeName}`
);
return fallbackNode;
}

View File

@@ -0,0 +1,52 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Newt } from "@server/db";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
bytesOut: number;
}
export const handleReceiveBandwidthMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
if (!message.data.bandwidthData) {
logger.warn("No bandwidth data provided");
}
const bandwidthData: PeerBandwidth[] = message.data.bandwidthData;
if (!Array.isArray(bandwidthData)) {
throw new Error("Invalid bandwidth data");
}
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Find the client by public key
const [client] = await trx
.select()
.from(clients)
.where(eq(clients.pubKey, publicKey))
.limit(1);
if (!client) {
continue;
}
// Update the client's bandwidth usage
await trx
.update(clients)
.set({
megabytesOut: (client.megabytesIn || 0) + bytesIn,
megabytesIn: (client.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
})
.where(eq(clients.clientId, client.clientId));
}
});
};

View File

@@ -1,174 +0,0 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import {
exitNodes,
resources,
sites,
Target,
targets
} from "@server/db";
import { eq, and, sql, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
export const handleRegisterMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context;
logger.info("Handling register message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const siteId = newt.siteId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site || !site.exitNodeId) {
logger.warn("Site not found or does not have exit node");
return;
}
await db
.update(sites)
.set({
pubKey: publicKey
})
.where(eq(sites.siteId, siteId))
.returning();
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (site.pubKey && site.pubKey !== publicKey) {
logger.info("Public key mismatch. Deleting old peer...");
await deletePeer(site.exitNodeId, site.pubKey);
}
if (!site.subnet) {
logger.warn("Site has no subnet");
return;
}
// add the peer to the exit node
await addPeer(site.exitNodeId, {
publicKey: publicKey,
allowedIps: [site.subnet]
});
// Improved version
const allResources = await db.transaction(async (tx) => {
// First get all resources for the site
const resourcesList = await tx
.select({
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol
})
.from(resources)
.where(eq(resources.siteId, siteId));
// Get all enabled targets for these resources in a single query
const resourceIds = resourcesList.map((r) => r.resourceId);
const allTargets =
resourceIds.length > 0
? await tx
.select({
resourceId: targets.resourceId,
targetId: targets.targetId,
ip: targets.ip,
method: targets.method,
port: targets.port,
internalPort: targets.internalPort,
enabled: targets.enabled
})
.from(targets)
.where(
and(
inArray(targets.resourceId, resourceIds),
eq(targets.enabled, true)
)
)
: [];
// Combine the data in JS instead of using SQL for the JSON
return resourcesList.map((resource) => ({
...resource,
targets: allTargets.filter(
(target) => target.resourceId === resource.resourceId
)
}));
});
const { tcpTargets, udpTargets } = allResources.reduce(
(acc, resource) => {
// Skip resources with no targets
if (!resource.targets?.length) return acc;
// Format valid targets into strings
const formattedTargets = resource.targets
.filter(
(target: Target) =>
target?.internalPort && target?.ip && target?.port
)
.map(
(target: Target) =>
`${target.internalPort}:${target.ip}:${target.port}`
);
// Add to the appropriate protocol array
if (resource.protocol === "tcp") {
acc.tcpTargets.push(...formattedTargets);
} else {
acc.udpTargets.push(...formattedTargets);
}
return acc;
},
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
);
return {
message: {
type: "newt/wg/connect",
data: {
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
publicKey: exitNode.publicKey,
serverIP: exitNode.address.split("/")[0],
tunnelIP: site.subnet.split("/")[0],
targets: {
udp: udpTargets,
tcp: tcpTargets
}
}
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};
};

View File

@@ -1,9 +1,11 @@
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { dockerSocketCache } from "./dockerSocket";
import { Newt } from "@server/db";
export const handleDockerStatusMessage: MessageHandler = async (context) => {
const { message, newt } = context;
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling Docker socket check response");
@@ -33,7 +35,8 @@ export const handleDockerStatusMessage: MessageHandler = async (context) => {
export const handleDockerContainersMessage: MessageHandler = async (
context
) => {
const { message, newt } = context;
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling Docker containers response");

View File

@@ -1,4 +1,7 @@
export * from "./createNewt";
export * from "./getToken";
export * from "./handleRegisterMessage";
export * from "./handleSocketMessages";
export * from "./getNewtToken";
export * from "./handleNewtRegisterMessage";
export * from "./handleReceiveBandwidthMessage";
export * from "./handleGetConfigMessage";
export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage";

View File

@@ -0,0 +1,114 @@
import { db } from "@server/db";
import { newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "../ws";
import logger from "@server/logger";
export async function addPeer(
siteId: number,
peer: {
publicKey: string;
allowedIps: string[];
endpoint: string;
}
) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Site found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/add",
data: peer
});
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`);
return site;
}
export async function deletePeer(siteId: number, publicKey: string) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/remove",
data: {
publicKey
}
});
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`);
return site;
}
export async function updatePeer(
siteId: number,
publicKey: string,
peer: {
allowedIps?: string[];
endpoint?: string;
}
) {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/update",
data: {
publicKey,
...peer
}
});
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`);
return site;
}

View File

@@ -0,0 +1,106 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { hash } from "@node-rs/argon2";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { newts } from "@server/db";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import moment from "moment";
import { generateSessionToken } from "@server/auth/sessions/app";
import { createNewtSession } from "@server/auth/sessions/newt";
import { fromError } from "zod-validation-error";
import { hashPassword } from "@server/auth/password";
export const createNewtBodySchema = z.object({});
export type CreateNewtBody = z.infer<typeof createNewtBodySchema>;
export type CreateNewtResponse = {
token: string;
newtId: string;
secret: string;
};
const createNewtSchema = z
.object({
newtId: z.string(),
secret: z.string()
})
.strict();
export async function createNewt(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = createNewtSchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { newtId, secret } = parsedBody.data;
if (req.user && !req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
const secretHash = await hashPassword(secret);
await db.insert(newts).values({
newtId: newtId,
secretHash,
dateCreated: moment().toISOString(),
});
// give the newt their default permissions:
// await db.insert(newtActions).values({
// newtId: newtId,
// actionId: ActionsEnum.createOrg,
// orgId: null,
// });
const token = generateSessionToken();
await createNewtSession(token, newtId);
return response<CreateNewtResponse>(res, {
data: {
newtId,
secret,
token,
},
success: true,
error: false,
message: "Newt created successfully",
status: HttpCode.OK,
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A newt with that email address already exists"
)
);
} else {
console.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create newt"
)
);
}
}
}

View File

@@ -0,0 +1,119 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import { db } from "@server/db";
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import {
createOlmSession,
validateOlmSessionToken
} from "@server/auth/sessions/olm";
import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
export const olmGetTokenBodySchema = z.object({
olmId: z.string(),
secret: z.string(),
token: z.string().optional()
});
export type OlmGetTokenBody = z.infer<typeof olmGetTokenBodySchema>;
export async function getOlmToken(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = olmGetTokenBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { olmId, secret, token } = parsedBody.data;
try {
if (token) {
const { session, olm } = await validateOlmSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Olm session already valid. Olm ID: ${olmId}. IP: ${req.ip}.`
);
}
return response<null>(res, {
data: null,
success: true,
error: false,
message: "Token session already valid",
status: HttpCode.OK
});
}
}
const existingOlmRes = await db
.select()
.from(olms)
.where(eq(olms.olmId, olmId));
if (!existingOlmRes || !existingOlmRes.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No olm found with that olmId"
)
);
}
const existingOlm = existingOlmRes[0];
const validSecret = await verifyPassword(
secret,
existingOlm.secretHash
);
if (!validSecret) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
);
}
return next(
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
);
}
logger.debug("Creating new olm session token");
const resToken = generateSessionToken();
await createOlmSession(resToken, existingOlm.olmId);
logger.debug("Token created successfully");
return response<{ token: string }>(res, {
data: {
token: resToken
},
success: true,
error: false,
message: "Token created successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to authenticate olm"
)
);
}
}

View File

@@ -0,0 +1,93 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Olm } from "@server/db";
import { eq, lt, isNull } from "drizzle-orm";
import logger from "@server/logger";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
*/
export const startOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = new Date(Date.now() - OFFLINE_THRESHOLD_MS);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
await db
.update(clients)
.set({ online: false })
.where(
eq(clients.online, true) &&
(lt(clients.lastPing, twoMinutesAgo.toISOString()) || isNull(clients.lastPing))
);
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.info("Started offline checker interval");
}
/**
* Stops the background interval that checks for offline clients
*/
export const stopOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped offline checker interval");
}
}
/**
* Handles ping messages from clients and responds with pong
*/
export const handleOlmPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
try {
// Update the client's last ping timestamp
await db
.update(clients)
.set({
lastPing: new Date().toISOString(),
online: true,
})
.where(eq(clients.clientId, olm.clientId));
} catch (error) {
logger.error("Error handling ping message", { error });
}
return {
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -0,0 +1,181 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import {
clients,
clientSites,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
const now = new Date().getTime() / 1000;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
const clientId = olm.clientId;
const { publicKey } = message.data;
if (!publicKey) {
logger.warn("Public key not provided");
return;
}
// Get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
if (client.exitNodeId) {
// Get the exit node for this site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, client.exitNodeId))
.limit(1);
// Send holepunch message for each site
sendToClient(olm.olmId, {
type: "olm/wg/holepunch",
data: {
serverPubKey: exitNode.publicKey
}
});
}
if (now - (client.lastHolePunch || 0) > 6) {
logger.warn("Client last hole punch is too old, skipping all sites");
return;
}
if (client.pubKey !== publicKey) {
logger.info(
"Public key mismatch. Updating public key and clearing session info..."
);
// Update the client's public key
await db
.update(clients)
.set({
pubKey: publicKey
})
.where(eq(clients.clientId, olm.clientId));
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
.update(clientSites)
.set({
isRelayed: false
})
.where(eq(clientSites.clientId, olm.clientId));
}
// Get all sites data
const sitesData = await db
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.where(eq(clientSites.clientId, client.clientId));
// Prepare an array to store site configurations
const siteConfigurations = [];
// Process each site
for (const { sites: site } of sitesData) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
);
continue;
}
// Validate endpoint and hole punch status
if (!site.endpoint) {
logger.warn(`Site ${site.siteId} has no endpoint, skipping`);
continue;
}
if (site.lastHolePunch && now - site.lastHolePunch > 6) {
logger.warn(
`Site ${site.siteId} last hole punch is too old, skipping`
);
continue;
}
// If public key changed, delete old peer from this site
if (client.pubKey && client.pubKey != publicKey) {
logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
}
if (!site.subnet) {
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
continue;
}
// Add the peer to the exit node for this site
if (client.endpoint) {
logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${client.endpoint}`
);
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split('/')[0]}/32`], // we want to only allow from that client
endpoint: client.endpoint
});
} else {
logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition`
);
}
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
// If we have no valid site configurations, don't send a connect message
if (siteConfigurations.length === 0) {
logger.warn("No valid site configurations found");
return;
}
// Return connect message with all site configurations
return {
message: {
type: "olm/wg/connect",
data: {
sites: siteConfigurations,
tunnelIP: client.subnet
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -0,0 +1,58 @@
import { db } from "@server/db";
import { MessageHandler } from "../ws";
import { clients, clientSites, Olm } from "@server/db";
import { eq } from "drizzle-orm";
import { updatePeer } from "../newt/peers";
import logger from "@server/logger";
export const handleOlmRelayMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
logger.info("Handling relay olm message!");
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Site not found or does not have exit node");
return;
}
// make sure we hand endpoints for both the site and the client and the lastHolePunch is not too old
if (!client.pubKey) {
logger.warn("Site or client has no endpoint or listen port");
return;
}
const { siteId } = message.data;
await db
.update(clientSites)
.set({
isRelayed: true
})
.where(eq(clientSites.clientId, olm.clientId));
// update the peer on the exit node
await updatePeer(siteId, client.pubKey, {
endpoint: "" // this removes the endpoint
});
return;
};

View File

@@ -0,0 +1,5 @@
export * from "./handleOlmRegisterMessage";
export * from "./getOlmToken";
export * from "./createOlm";
export * from "./handleOlmRelayMessage";
export * from "./handleOlmPingMessage";

View File

@@ -0,0 +1,92 @@
import { db } from "@server/db";
import { clients, olms, newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "../ws";
import logger from "@server/logger";
export async function addPeer(
clientId: number,
peer: {
siteId: number;
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
}
) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: "olm/wg/peer/add",
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort
}
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
}
export async function deletePeer(clientId: number, siteId: number, publicKey: string) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: "olm/wg/peer/remove",
data: {
publicKey,
siteId: siteId
}
});
logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`);
}
export async function updatePeer(
clientId: number,
peer: {
siteId: number;
publicKey: string;
endpoint: string;
serverIP: string | null;
serverPort: number | null;
}
) {
const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
sendToClient(olm.olmId, {
type: "olm/wg/peer/update",
data: {
siteId: peer.siteId,
publicKey: peer.publicKey,
endpoint: peer.endpoint,
serverIP: peer.serverIP,
serverPort: peer.serverPort
}
});
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`);
}

View File

@@ -23,16 +23,16 @@ import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import { defaultRoleAllowedActions } from "../role";
import { OpenAPITags, registry } from "@server/openApi";
import { isValidCIDR } from "@server/lib/validators";
const createOrgSchema = z
.object({
orgId: z.string(),
name: z.string().min(1).max(255)
name: z.string().min(1).max(255),
subnet: z.string()
})
.strict();
// const MAX_ORGS = 5;
registry.registerPath({
method: "put",
path: "/org",
@@ -78,7 +78,32 @@ export async function createOrg(
);
}
const { orgId, name } = parsedBody.data;
const { orgId, name, subnet } = parsedBody.data;
if (!isValidCIDR(subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
// make sure the subnet is unique
const subnetExists = await db
.select()
.from(orgs)
.where(eq(orgs.subnet, subnet))
.limit(1);
if (subnetExists.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
// make sure the orgId is unique
const orgExists = await db
@@ -109,7 +134,8 @@ export async function createOrg(
.insert(orgs)
.values({
orgId,
name
name,
subnet
})
.returning();
@@ -142,25 +168,25 @@ export async function createOrg(
// Get all actions and create role actions
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length > 0) {
await trx
.insert(roleActions)
.values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
await trx.insert(roleActions).values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
}
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
if (allDomains.length) {
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
}
if (req.user) {
await trx.insert(userOrgs).values({
@@ -187,7 +213,7 @@ export async function createOrg(
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
});
}
const memberRole = await trx

View File

@@ -89,6 +89,8 @@ export async function deleteOrg(
.where(eq(sites.orgId, orgId))
.limit(1);
const deletedNewtIds: string[] = [];
await db.transaction(async (trx) => {
if (sites) {
for (const site of orgSites) {
@@ -102,11 +104,7 @@ export async function deleteOrg(
.where(eq(newts.siteId, site.siteId))
.returning();
if (deletedNewt) {
const payload = {
type: `newt/terminate`,
data: {}
};
sendToClient(deletedNewt.newtId, payload);
deletedNewtIds.push(deletedNewt.newtId);
// delete all of the sessions for the newt
await trx
@@ -131,6 +129,18 @@ export async function deleteOrg(
await trx.delete(orgs).where(eq(orgs.orgId, orgId));
});
// Send termination messages outside of transaction to prevent blocking
for (const newtId of deletedNewtIds) {
const payload = {
type: `newt/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(newtId, payload).catch(error => {
logger.error("Failed to send termination message to newt:", error);
});
}
return response(res, {
data: null,
success: true,

View File

@@ -6,3 +6,4 @@ export * from "./listUserOrgs";
export * from "./checkId";
export * from "./getOrgOverview";
export * from "./listOrgs";
export * from "./pickOrgDefaults";

View File

@@ -5,7 +5,7 @@ import { Org, orgs, userOrgs } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { sql, inArray, eq } from "drizzle-orm";
import { sql, inArray, eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
@@ -40,8 +40,10 @@ const listOrgsSchema = z.object({
// responses: {}
// });
type ResponseOrg = Org & { isOwner?: boolean };
export type ListUserOrgsResponse = {
orgs: Org[];
orgs: ResponseOrg[];
pagination: { total: number; limit: number; offset: number };
};
@@ -106,6 +108,10 @@ export async function listUserOrgs(
.select()
.from(orgs)
.where(inArray(orgs.orgId, userOrgIds))
.leftJoin(
userOrgs,
and(eq(userOrgs.orgId, orgs.orgId), eq(userOrgs.userId, userId))
)
.limit(limit)
.offset(offset);
@@ -115,9 +121,19 @@ export async function listUserOrgs(
.where(inArray(orgs.orgId, userOrgIds));
const totalCount = totalCountResult[0].count;
const responseOrgs = organizations.map((val) => {
const res = {
...val.orgs
} as ResponseOrg;
if (val.userOrgs && val.userOrgs.isOwner) {
res.isOwner = val.userOrgs.isOwner;
}
return res;
});
return response<ListUserOrgsResponse>(res, {
data: {
orgs: organizations,
orgs: responseOrgs,
pagination: {
total: totalCount,
limit,

View File

@@ -0,0 +1,39 @@
import { Request, Response, NextFunction } from "express";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import config from "@server/lib/config";
export type PickOrgDefaultsResponse = {
subnet: string;
};
export async function pickOrgDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
// TODO: Why would each org have to have its own subnet?
// const subnet = await getNextAvailableOrgSubnet();
// Just hard code the subnet for now for everyone
const subnet = config.getRawConfig().orgs.subnet_group;
return response<PickOrgDefaultsResponse>(res, {
data: {
subnet: subnet
},
success: true,
error: false,
message: "Organization defaults created successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -21,6 +21,7 @@ import logger from "@server/logger";
import { subdomainSchema } from "@server/lib/schemas";
import config from "@server/lib/config";
import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build";
const createResourceParamsSchema = z
.object({
@@ -36,7 +37,6 @@ const createHttpResourceSchema = z
.string()
.optional()
.transform((val) => val?.toLowerCase()),
isBaseDomain: z.boolean().optional(),
siteId: z.number(),
http: z.boolean(),
protocol: z.enum(["tcp", "udp"]),
@@ -52,19 +52,6 @@ const createHttpResourceSchema = z
},
{ message: "Invalid subdomain" }
)
.refine(
(data) => {
if (!config.getRawConfig().flags?.allow_base_domain_resources) {
if (data.isBaseDomain) {
return false;
}
}
return true;
},
{
message: "Base domain resources are not allowed"
}
);
const createRawResourceSchema = z
.object({
@@ -101,9 +88,12 @@ registry.registerPath({
body: {
content: {
"application/json": {
schema: createHttpResourceSchema.or(
createRawResourceSchema
)
schema:
build == "oss"
? createHttpResourceSchema.or(
createRawResourceSchema
)
: createHttpResourceSchema
}
}
}
@@ -166,6 +156,14 @@ export async function createResource(
{ siteId, orgId }
);
} else {
if (!config.getRawConfig().flags?.allow_raw_resources && build == "oss") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Raw resources are not allowed"
)
);
}
return await createRawResource(
{ req, res, next },
{ siteId, orgId }
@@ -203,35 +201,81 @@ async function createHttpResource(
);
}
const { name, subdomain, isBaseDomain, http, protocol, domainId } =
parsedBody.data;
const { name, subdomain, domainId } = parsedBody.data;
const [orgDomain] = await db
const [domainRes] = await db
.select()
.from(orgDomains)
.where(
.from(domains)
.where(eq(domains.domainId, domainId))
.leftJoin(
orgDomains,
and(eq(orgDomains.orgId, orgId), eq(orgDomains.domainId, domainId))
)
.leftJoin(domains, eq(orgDomains.domainId, domains.domainId));
);
if (!orgDomain || !orgDomain.domains) {
if (!domainRes || !domainRes.domains) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Domain with ID ${parsedBody.data.domainId} not found`
`Domain with ID ${domainId} not found`
)
);
}
const domain = orgDomain.domains;
if (domainRes.orgDomains && domainRes.orgDomains.orgId !== orgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
`Organization does not have access to domain with ID ${domainId}`
)
);
}
if (!domainRes.domains.verified) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Domain with ID ${domainRes.domains.domainId} is not verified`
)
);
}
let fullDomain = "";
if (isBaseDomain) {
fullDomain = domain.baseDomain;
} else {
fullDomain = `${subdomain}.${domain.baseDomain}`;
if (domainRes.domains.type == "ns") {
if (subdomain) {
fullDomain = `${subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
} else if (domainRes.domains.type == "cname") {
fullDomain = domainRes.domains.baseDomain;
} else if (domainRes.domains.type == "wildcard") {
if (subdomain) {
// the subdomain cant have a dot in it
const parsedSubdomain = subdomainSchema.safeParse(subdomain);
if (!parsedSubdomain.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedSubdomain.error).toString()
)
);
}
if (parsedSubdomain.data.includes(".")) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Subdomain cannot contain a dot when using wildcard domains"
)
);
}
fullDomain = `${subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
}
fullDomain = fullDomain.toLowerCase();
logger.debug(`Full domain: ${fullDomain}`);
// make sure the full domain is unique
@@ -261,10 +305,10 @@ async function createHttpResource(
orgId,
name,
subdomain,
http,
protocol,
http: true,
protocol: "tcp",
ssl: true,
isBaseDomain
isBaseDomain: false
})
.returning();

View File

@@ -69,7 +69,8 @@ function queryResources(
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(sites, eq(resources.siteId, sites.siteId))
@@ -103,7 +104,8 @@ function queryResources(
http: resources.http,
protocol: resources.protocol,
proxyPort: resources.proxyPort,
enabled: resources.enabled
enabled: resources.enabled,
domainId: resources.domainId
})
.from(resources)
.leftJoin(sites, eq(resources.siteId, sites.siteId))

View File

@@ -20,6 +20,7 @@ import { tlsNameSchema } from "@server/lib/schemas";
import { subdomainSchema } from "@server/lib/schemas";
import { registry } from "@server/openApi";
import { OpenAPITags } from "@server/openApi";
import { build } from "@server/build";
const updateResourceParamsSchema = z
.object({
@@ -40,7 +41,6 @@ const updateHttpResourceBodySchema = z
sso: z.boolean().optional(),
blockAccess: z.boolean().optional(),
emailWhitelistEnabled: z.boolean().optional(),
isBaseDomain: z.boolean().optional(),
applyRules: z.boolean().optional(),
domainId: z.string().optional(),
enabled: z.boolean().optional(),
@@ -61,19 +61,6 @@ const updateHttpResourceBodySchema = z
},
{ message: "Invalid subdomain" }
)
.refine(
(data) => {
if (!config.getRawConfig().flags?.allow_base_domain_resources) {
if (data.isBaseDomain) {
return false;
}
}
return true;
},
{
message: "Base domain resources are not allowed"
}
)
.refine(
(data) => {
if (data.tlsServerName) {
@@ -134,9 +121,12 @@ registry.registerPath({
body: {
content: {
"application/json": {
schema: updateHttpResourceBodySchema.and(
updateRawResourceBodySchema
)
schema:
build == "oss"
? updateHttpResourceBodySchema.and(
updateRawResourceBodySchema
)
: updateHttpResourceBodySchema
}
}
}
@@ -242,86 +232,120 @@ async function updateHttpResource(
const updateData = parsedBody.data;
if (updateData.domainId) {
const [existingDomain] = await db
.select()
.from(orgDomains)
.where(
and(
eq(orgDomains.orgId, org.orgId),
eq(orgDomains.domainId, updateData.domainId)
)
)
.leftJoin(domains, eq(orgDomains.domainId, domains.domainId));
const domainId = updateData.domainId;
if (!existingDomain) {
const [domainRes] = await db
.select()
.from(domains)
.where(eq(domains.domainId, domainId))
.leftJoin(
orgDomains,
and(
eq(orgDomains.orgId, resource.orgId),
eq(orgDomains.domainId, domainId)
)
);
if (!domainRes || !domainRes.domains) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Domain not found`)
createHttpError(
HttpCode.NOT_FOUND,
`Domain with ID ${updateData.domainId} not found`
)
);
}
}
const domainId = updateData.domainId || resource.domainId!;
const subdomain = updateData.subdomain || resource.subdomain;
const [domain] = await db
.select()
.from(domains)
.where(eq(domains.domainId, domainId));
const isBaseDomain =
updateData.isBaseDomain !== undefined
? updateData.isBaseDomain
: resource.isBaseDomain;
let fullDomain: string | null = null;
if (isBaseDomain) {
fullDomain = domain.baseDomain;
} else if (subdomain && domain) {
fullDomain = `${subdomain}.${domain.baseDomain}`;
}
if (fullDomain) {
const [existingDomain] = await db
.select()
.from(resources)
.where(eq(resources.fullDomain, fullDomain));
if (
existingDomain &&
existingDomain.resourceId !== resource.resourceId
domainRes.orgDomains &&
domainRes.orgDomains.orgId !== resource.orgId
) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource with that domain already exists"
HttpCode.FORBIDDEN,
`You do not have permission to use domain with ID ${updateData.domainId}`
)
);
}
}
const updatePayload = {
...updateData,
fullDomain
};
if (!domainRes.domains.verified) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Domain with ID ${updateData.domainId} is not verified`
)
);
}
let fullDomain = "";
if (domainRes.domains.type == "ns") {
if (updateData.subdomain) {
fullDomain = `${updateData.subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
} else if (domainRes.domains.type == "cname") {
fullDomain = domainRes.domains.baseDomain;
} else if (domainRes.domains.type == "wildcard") {
if (updateData.subdomain) {
// the subdomain cant have a dot in it
const parsedSubdomain = subdomainSchema.safeParse(updateData.subdomain);
if (!parsedSubdomain.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedSubdomain.error).toString()
)
);
}
if (parsedSubdomain.data.includes(".")) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Subdomain cannot contain a dot when using wildcard domains"
)
);
}
fullDomain = `${updateData.subdomain}.${domainRes.domains.baseDomain}`;
} else {
fullDomain = domainRes.domains.baseDomain;
}
}
fullDomain = fullDomain.toLowerCase();
logger.debug(`Full domain: ${fullDomain}`);
if (fullDomain) {
const [existingDomain] = await db
.select()
.from(resources)
.where(eq(resources.fullDomain, fullDomain));
if (
existingDomain &&
existingDomain.resourceId !== resource.resourceId
) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource with that domain already exists"
)
);
}
}
// update the full domain if it has changed
if (fullDomain && fullDomain !== resource.fullDomain) {
await db
.update(resources)
.set({ fullDomain })
.where(eq(resources.resourceId, resource.resourceId));
}
}
const updatedResource = await db
.update(resources)
.set({
name: updatePayload.name,
subdomain: updatePayload.subdomain,
ssl: updatePayload.ssl,
sso: updatePayload.sso,
blockAccess: updatePayload.blockAccess,
emailWhitelistEnabled: updatePayload.emailWhitelistEnabled,
isBaseDomain: updatePayload.isBaseDomain,
applyRules: updatePayload.applyRules,
domainId: updatePayload.domainId,
enabled: updatePayload.enabled,
stickySession: updatePayload.stickySession,
tlsServerName: updatePayload.tlsServerName,
setHostHeader: updatePayload.setHostHeader,
fullDomain: updatePayload.fullDomain
})
.set(updateData)
.where(eq(resources.resourceId, resource.resourceId))
.returning();

View File

@@ -14,6 +14,9 @@ import { newts } from "@server/db";
import moment from "moment";
import { OpenAPITags, registry } from "@server/openApi";
import { hashPassword } from "@server/auth/password";
import { isValidIP } from "@server/lib/validators";
import { isIpInCidr } from "@server/lib/ip";
import config from "@server/lib/config";
const createSiteParamsSchema = z
.object({
@@ -35,9 +38,18 @@ const createSiteSchema = z
subnet: z.string().optional(),
newtId: z.string().optional(),
secret: z.string().optional(),
address: z.string().optional(),
type: z.enum(["newt", "wireguard", "local"])
})
.strict();
.strict()
.refine((data) => {
if (data.type === "local") {
return !config.getRawConfig().flags?.disable_local_sites;
} else if (data.type === "wireguard") {
return !config.getRawConfig().flags?.disable_basic_wireguard_sites;
}
return true;
});
export type CreateSiteBody = z.infer<typeof createSiteSchema>;
@@ -84,7 +96,8 @@ export async function createSite(
pubKey,
subnet,
newtId,
secret
secret,
address
} = parsedBody.data;
const parsedParams = createSiteParamsSchema.safeParse(req.params);
@@ -116,6 +129,59 @@ export async function createSite(
);
}
let updatedAddress = null;
if (address) {
if (!isValidIP(address)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid subnet format. Please provide a valid CIDR notation."
)
);
}
if (!isIpInCidr(address, org.subnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IP is not in the CIDR range of the subnet."
)
);
}
updatedAddress = `${address}/${org.subnet.split("/")[1]}`; // we want the block size of the whole org
// make sure the subnet is unique
const addressExistsSites = await db
.select()
.from(sites)
.where(eq(sites.address, updatedAddress))
.limit(1);
if (addressExistsSites.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
const addressExistsClients = await db
.select()
.from(sites)
.where(eq(sites.subnet, updatedAddress))
.limit(1);
if (addressExistsClients.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
`Subnet ${subnet} already exists`
)
);
}
}
const niceId = await getUniqueSiteName(orgId);
await db.transaction(async (trx) => {
@@ -139,6 +205,7 @@ export async function createSite(
exitNodeId,
name,
niceId,
// address: updatedAddress || null,
subnet,
type,
dockerSocketEnabled: type == "newt",
@@ -154,6 +221,7 @@ export async function createSite(
orgId,
name,
niceId,
// address: updatedAddress || null,
type,
dockerSocketEnabled: type == "newt",
subnet: "0.0.0.0/0"

View File

@@ -62,6 +62,8 @@ export async function deleteSite(
);
}
let deletedNewtId: string | null = null;
await db.transaction(async (trx) => {
if (site.pubKey) {
if (site.type == "wireguard") {
@@ -73,11 +75,7 @@ export async function deleteSite(
.where(eq(newts.siteId, siteId))
.returning();
if (deletedNewt) {
const payload = {
type: `newt/terminate`,
data: {}
};
sendToClient(deletedNewt.newtId, payload);
deletedNewtId = deletedNewt.newtId;
// delete all of the sessions for the newt
await trx
@@ -90,6 +88,18 @@ export async function deleteSite(
await trx.delete(sites).where(eq(sites.siteId, siteId));
});
// Send termination message outside of transaction to prevent blocking
if (deletedNewtId) {
const payload = {
type: `newt/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(deletedNewtId, payload).catch(error => {
logger.error("Failed to send termination message to newt:", error);
});
}
return response(res, {
data: null,
success: true,

View File

@@ -1,4 +1,4 @@
import { db } from "@server/db";
import { db, newts } from "@server/db";
import { orgs, roleSites, sites, userSites } from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
@@ -9,6 +9,42 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import NodeCache from "node-cache";
import semver from "semver";
const newtVersionCache = new NodeCache({ stdTTL: 3600 }); // 1 hours in seconds
async function getLatestNewtVersion(): Promise<string | null> {
try {
const cachedVersion = newtVersionCache.get<string>("latestNewtVersion");
if (cachedVersion) {
return cachedVersion;
}
const response = await fetch(
"https://api.github.com/repos/fosrl/newt/tags"
);
if (!response.ok) {
logger.warn("Failed to fetch latest Newt version from GitHub");
return null;
}
const tags = await response.json();
if (!Array.isArray(tags) || tags.length === 0) {
logger.warn("No tags found for Newt repository");
return null;
}
const latestVersion = tags[0].name;
newtVersionCache.set("latestNewtVersion", latestVersion);
return latestVersion;
} catch (error) {
logger.error("Error fetching latest Newt version:", error);
return null;
}
}
const listSitesParamsSchema = z
.object({
@@ -43,10 +79,13 @@ function querySites(orgId: string, accessibleSiteIds: number[]) {
megabytesOut: sites.megabytesOut,
orgName: orgs.name,
type: sites.type,
online: sites.online
online: sites.online,
address: sites.address,
newtVersion: newts.version
})
.from(sites)
.leftJoin(orgs, eq(sites.orgId, orgs.orgId))
.leftJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
inArray(sites.siteId, accessibleSiteIds),
@@ -55,8 +94,12 @@ function querySites(orgId: string, accessibleSiteIds: number[]) {
);
}
type SiteWithUpdateAvailable = Awaited<ReturnType<typeof querySites>>[0] & {
newtUpdateAvailable?: boolean;
};
export type ListSitesResponse = {
sites: Awaited<ReturnType<typeof querySites>>;
sites: SiteWithUpdateAvailable[];
pagination: { total: number; limit: number; offset: number };
};
@@ -147,9 +190,36 @@ export async function listSites(
const totalCountResult = await countQuery;
const totalCount = totalCountResult[0].count;
const latestNewtVersion = await getLatestNewtVersion();
const sitesWithUpdates: SiteWithUpdateAvailable[] = sitesList.map(
(site) => {
const siteWithUpdate: SiteWithUpdateAvailable = { ...site };
if (
site.type === "newt" &&
site.newtVersion &&
latestNewtVersion
) {
try {
siteWithUpdate.newtUpdateAvailable = semver.lt(
site.newtVersion,
latestNewtVersion
);
} catch (error) {
siteWithUpdate.newtUpdateAvailable = false;
}
} else {
siteWithUpdate.newtUpdateAvailable = false;
}
return siteWithUpdate;
}
);
return response<ListSitesResponse>(res, {
data: {
sites: sitesList,
sites: sitesWithUpdates,
pagination: {
total: totalCount,
limit,

View File

@@ -6,10 +6,11 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { findNextAvailableCidr } from "@server/lib/ip";
import { findNextAvailableCidr, getNextAvailableClientSubnet } from "@server/lib/ip";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import { OpenAPITags, registry } from "@server/openApi";
import { fromError } from "zod-validation-error";
import { z } from "zod";
export type PickSiteDefaultsResponse = {
@@ -19,9 +20,10 @@ export type PickSiteDefaultsResponse = {
name: string;
listenPort: number;
endpoint: string;
subnet: string;
subnet: string; // TODO: make optional?
newtId: string;
newtSecret: string;
clientAddress?: string;
};
registry.registerPath({
@@ -38,12 +40,29 @@ registry.registerPath({
responses: {}
});
const pickSiteDefaultsSchema = z
.object({
orgId: z.string()
})
.strict();
export async function pickSiteDefaults(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = pickSiteDefaultsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
// TODO: more intelligent way to pick the exit node
// make sure there is an exit node by counting the exit nodes table
@@ -67,7 +86,7 @@ export async function pickSiteDefaults(
.where(eq(sites.exitNodeId, exitNode.exitNodeId));
// TODO: we need to lock this subnet for some time so someone else does not take it
let subnets = sitesQuery.map((site) => site.subnet);
let subnets = sitesQuery.map((site) => site.subnet).filter((subnet) => subnet !== null);
// exclude the exit node address by replacing after the / with a site block size
subnets.push(
exitNode.address.replace(
@@ -89,6 +108,18 @@ export async function pickSiteDefaults(
);
}
const newClientAddress = await getNextAvailableClientSubnet(orgId);
if (!newClientAddress) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnet found"
)
);
}
const clientAddress = newClientAddress.split("/")[0];
const newtId = generateId(15);
const secret = generateId(48);
@@ -100,7 +131,9 @@ export async function pickSiteDefaults(
name: exitNode.name,
listenPort: exitNode.listenPort,
endpoint: exitNode.endpoint,
// subnet: `${newSubnet.split("/")[0]}/${config.getRawConfig().gerbil.block_size}`, // we want the block size of the whole subnet
subnet: newSubnet,
clientAddress: clientAddress,
newtId,
newtSecret: secret
},

View File

@@ -1,11 +1,12 @@
import { Request, Response } from "express";
import { db } from "@server/db";
import { db, exitNodes } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import config from "@server/lib/config";
import { orgs, resources, sites, Target, targets } from "@server/db";
import { sql } from "drizzle-orm";
let currentExitNodeId: number;
export async function traefikConfigProvider(
_: Request,
@@ -15,16 +16,48 @@ export async function traefikConfigProvider(
// Get all resources with related data
const allResources = await db.transaction(async (tx) => {
// First query to get resources with site and org info
// Get the current exit node name from config
if (!currentExitNodeId) {
if (config.getRawConfig().gerbil.exit_node_name) {
const exitNodeName =
config.getRawConfig().gerbil.exit_node_name!;
const [exitNode] = await tx
.select({
exitNodeId: exitNodes.exitNodeId
})
.from(exitNodes)
.where(eq(exitNodes.name, exitNodeName));
if (!exitNode) {
logger.error(
`Exit node with name ${exitNodeName} not found in the database`
);
return [];
}
currentExitNodeId = exitNode.exitNodeId;
} else {
const [exitNode] = await tx
.select({
exitNodeId: exitNodes.exitNodeId
})
.from(exitNodes)
.limit(1);
if (!exitNode) {
logger.error("No exit node found in the database");
return [];
}
currentExitNodeId = exitNode.exitNodeId;
}
}
// Get the site(s) on this exit node
const resourcesWithRelations = await tx
.select({
// Resource fields
resourceId: resources.resourceId,
subdomain: resources.subdomain,
fullDomain: resources.fullDomain,
ssl: resources.ssl,
blockAccess: resources.blockAccess,
sso: resources.sso,
emailWhitelistEnabled: resources.emailWhitelistEnabled,
http: resources.http,
proxyPort: resources.proxyPort,
protocol: resources.protocol,
@@ -34,11 +67,8 @@ export async function traefikConfigProvider(
site: {
siteId: sites.siteId,
type: sites.type,
subnet: sites.subnet
},
// Org fields
org: {
orgId: orgs.orgId
subnet: sites.subnet,
exitNodeId: sites.exitNodeId,
},
enabled: resources.enabled,
stickySession: resources.stickySession,
@@ -47,7 +77,7 @@ export async function traefikConfigProvider(
})
.from(resources)
.innerJoin(sites, eq(sites.siteId, resources.siteId))
.innerJoin(orgs, eq(resources.orgId, orgs.orgId));
.where(eq(sites.exitNodeId, currentExitNodeId));
// Get all resource IDs from the first query
const resourceIds = resourcesWithRelations.map((r) => r.resourceId);
@@ -140,7 +170,6 @@ export async function traefikConfigProvider(
for (const resource of allResources) {
const targets = resource.targets as Target[];
const site = resource.site;
const org = resource.org;
const routerName = `${resource.resourceId}-router`;
const serviceName = `${resource.resourceId}-service`;
@@ -164,11 +193,6 @@ export async function traefikConfigProvider(
continue;
}
// HTTP configuration remains the same
if (!resource.subdomain && !resource.isBaseDomain) {
continue;
}
// add routers and services empty objects if they don't exist
if (!config_output.http.routers) {
config_output.http.routers = {};
@@ -192,26 +216,22 @@ export async function traefikConfigProvider(
const configDomain = config.getDomain(resource.domainId);
if (!configDomain) {
logger.error(
`Failed to get domain from config for resource ${resource.resourceId}`
);
continue;
let tls = {};
if (configDomain) {
tls = {
certResolver: configDomain.cert_resolver,
...(configDomain.prefer_wildcard_cert
? {
domains: [
{
main: wildCard
}
]
}
: {})
};
}
const tls = {
certResolver: configDomain.cert_resolver,
...(configDomain.prefer_wildcard_cert
? {
domains: [
{
main: wildCard
}
]
}
: {})
};
const additionalMiddlewares =
config.getRawConfig().traefik.additional_middlewares || [];
@@ -227,6 +247,7 @@ export async function traefikConfigProvider(
],
service: serviceName,
rule: `Host(\`${fullDomain}\`)`,
priority: 100,
...(resource.ssl ? { tls } : {})
};
@@ -237,7 +258,8 @@ export async function traefikConfigProvider(
],
middlewares: [redirectHttpsMiddlewareName],
service: serviceName,
rule: `Host(\`${fullDomain}\`)`
rule: `Host(\`${fullDomain}\`)`,
priority: 100
};
}
@@ -262,7 +284,7 @@ export async function traefikConfigProvider(
} else if (site.type === "newt") {
if (
!target.internalPort ||
!target.method
!target.method || !site.subnet
) {
return false;
}
@@ -278,7 +300,7 @@ export async function traefikConfigProvider(
url: `${target.method}://${target.ip}:${target.port}`
};
} else if (site.type === "newt") {
const ip = site.subnet.split("/")[0];
const ip = site.subnet!.split("/")[0];
return {
url: `${target.method}://${ip}:${target.internalPort}`
};
@@ -309,7 +331,9 @@ export async function traefikConfigProvider(
// if defined in the static config and here. if not set, self-signed certs won't work
insecureSkipVerify: true
};
config_output.http.services![serviceName].loadBalancer.serversTransport = transportName;
config_output.http.services![
serviceName
].loadBalancer.serversTransport = transportName;
}
// Add the host header middleware
@@ -317,23 +341,22 @@ export async function traefikConfigProvider(
if (!config_output.http.middlewares) {
config_output.http.middlewares = {};
}
config_output.http.middlewares[hostHeaderMiddlewareName] =
{
headers: {
customRequestHeaders: {
Host: resource.setHostHeader
}
config_output.http.middlewares[hostHeaderMiddlewareName] = {
headers: {
customRequestHeaders: {
Host: resource.setHostHeader
}
};
}
};
if (!config_output.http.routers![routerName].middlewares) {
config_output.http.routers![routerName].middlewares = [];
config_output.http.routers![routerName].middlewares =
[];
}
config_output.http.routers![routerName].middlewares = [
...config_output.http.routers![routerName].middlewares,
hostHeaderMiddlewareName
];
}
} else {
// Non-HTTP (TCP/UDP) configuration
const protocol = resource.protocol.toLowerCase();
@@ -371,7 +394,7 @@ export async function traefikConfigProvider(
return false;
}
} else if (site.type === "newt") {
if (!target.internalPort) {
if (!target.internalPort || !site.subnet) {
return false;
}
}
@@ -386,7 +409,7 @@ export async function traefikConfigProvider(
address: `${target.ip}:${target.port}`
};
} else if (site.type === "newt") {
const ip = site.subnet.split("/")[0];
const ip = site.subnet!.split("/")[0];
return {
address: `${ip}:${target.internalPort}`
};

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { db, UserOrg } from "@server/db";
import { roles, userInvites, userOrgs, users } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
@@ -92,6 +92,7 @@ export async function acceptInvite(
}
let roleId: number;
let totalUsers: UserOrg[] | undefined;
// get the role to make sure it exists
const existingRole = await db
.select()
@@ -122,6 +123,12 @@ export async function acceptInvite(
await trx
.delete(userInvites)
.where(eq(userInvites.inviteId, inviteId));
// Get the total number of users in the org now
totalUsers = await db
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, existingInvite.orgId));
});
return response<AcceptInviteResponse>(res, {

View File

@@ -6,7 +6,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { db } from "@server/db";
import { db, UserOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app";
@@ -135,65 +135,76 @@ export async function createOrgUser(
);
}
const [existingUser] = await db
.select()
.from(users)
.where(eq(users.username, username));
let orgUsers: UserOrg[] | undefined;
if (existingUser) {
const [existingOrgUser] = await db
await db.transaction(async (trx) => {
const [existingUser] = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, existingUser.userId)
)
);
.from(users)
.where(eq(users.username, username));
if (existingOrgUser) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User already exists in this organization"
)
);
if (existingUser) {
const [existingOrgUser] = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, existingUser.userId)
)
);
if (existingOrgUser) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User already exists in this organization"
)
);
}
await trx
.insert(userOrgs)
.values({
orgId,
userId: existingUser.userId,
roleId: role.roleId
})
.returning();
} else {
const userId = generateId(15);
const [newUser] = await trx
.insert(users)
.values({
userId: userId,
email,
username,
name,
type: "oidc",
idpId,
dateCreated: new Date().toISOString(),
emailVerified: true
})
.returning();
await trx
.insert(userOrgs)
.values({
orgId,
userId: newUser.userId,
roleId: role.roleId
})
.returning();
}
await db
.insert(userOrgs)
.values({
orgId,
userId: existingUser.userId,
roleId: role.roleId
})
.returning();
} else {
const userId = generateId(15);
// List all of the users in the org
orgUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
});
const [newUser] = await db
.insert(users)
.values({
userId: userId,
email,
username,
name,
type: "oidc",
idpId,
dateCreated: new Date().toISOString(),
emailVerified: true
})
.returning();
await db
.insert(userOrgs)
.values({
orgId,
userId: newUser.userId,
roleId: role.roleId
})
.returning();
}
} else {
return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -99,6 +99,7 @@ export async function inviteUser(
regenerate
} = parsedBody.data;
// Check if the organization exists
const org = await db
.select()

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resources, sites } from "@server/db";
import { db, resources, sites, UserOrg } from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, eq, exists } from "drizzle-orm";
import response from "@server/lib/response";
@@ -65,6 +65,8 @@ export async function removeUserOrg(
);
}
let userCount: UserOrg[] | undefined;
await db.transaction(async (trx) => {
await trx
.delete(userOrgs)
@@ -108,6 +110,11 @@ export async function removeUserOrg(
)
)
);
userCount = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
});
return response(res, {

View File

@@ -3,25 +3,33 @@ import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { IncomingMessage } from "http";
import { Socket } from "net";
import { Newt, newts, NewtSession } from "@server/db";
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
import { messageHandlers } from "./messageHandlers";
import logger from "@server/logger";
import redisManager from "@server/db/redis";
import { v4 as uuidv4 } from "uuid";
// Custom interfaces
interface WebSocketRequest extends IncomingMessage {
token?: string;
}
type ClientType = 'newt' | 'olm';
interface AuthenticatedWebSocket extends WebSocket {
newt?: Newt;
client?: Newt | Olm;
clientType?: ClientType;
connectionId?: string;
}
interface TokenPayload {
newt: Newt;
session: NewtSession;
client: Newt | Olm;
session: NewtSession | OlmSession;
clientType: ClientType;
}
interface WSMessage {
@@ -33,55 +41,121 @@ interface HandlerResponse {
message: WSMessage;
broadcast?: boolean;
excludeSender?: boolean;
targetNewtId?: string;
targetClientId?: string;
}
interface HandlerContext {
message: WSMessage;
senderWs: WebSocket;
newt: Newt | undefined;
sendToClient: (newtId: string, message: WSMessage) => boolean;
broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void;
client: Newt | Olm | undefined;
clientType: ClientType;
sendToClient: (clientId: string, message: WSMessage) => Promise<boolean>;
broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => Promise<void>;
connectedClients: Map<string, WebSocket[]>;
}
interface RedisMessage {
type: 'direct' | 'broadcast';
targetClientId?: string;
excludeClientId?: string;
message: WSMessage;
fromNodeId: string;
}
export type MessageHandler = (context: HandlerContext) => Promise<HandlerResponse | void>;
const router: Router = Router();
const wss: WebSocketServer = new WebSocketServer({ noServer: true });
// Client tracking map
// Generate unique node ID for this instance
const NODE_ID = uuidv4();
const REDIS_CHANNEL = 'websocket_messages';
// Client tracking map (local to this node)
let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Helper to get map key
const getClientMapKey = (clientId: string) => clientId;
// Redis keys (generalized)
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
const getNodeConnectionsKey = (nodeId: string, clientId: string) => `ws:node:${nodeId}:${clientId}`;
// Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
await redisManager.subscribe(REDIS_CHANNEL, async (channel: string, message: string) => {
try {
const redisMessage: RedisMessage = JSON.parse(message);
// Ignore messages from this node
if (redisMessage.fromNodeId === NODE_ID) return;
if (redisMessage.type === 'direct' && redisMessage.targetClientId) {
// Send to specific client on this node
await sendToClientLocal(redisMessage.targetClientId, redisMessage.message);
} else if (redisMessage.type === 'broadcast') {
// Broadcast to all clients on this node except excluded
await broadcastToAllExceptLocal(redisMessage.message, redisMessage.excludeClientId);
}
} catch (error) {
logger.error('Error processing Redis message:', error);
}
});
};
// Helper functions for client management
const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
const existingClients = connectedClients.get(newtId) || [];
const addClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise<void> => {
// Generate unique connection ID
const connectionId = uuidv4();
ws.connectionId = connectionId;
// Add to local tracking
const mapKey = getClientMapKey(clientId);
const existingClients = connectedClients.get(mapKey) || [];
existingClients.push(ws);
connectedClients.set(newtId, existingClients);
logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`);
connectedClients.set(mapKey, existingClients);
// Add to Redis tracking if enabled
if (redisManager.isRedisEnabled()) {
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
await redisManager.hset(getNodeConnectionsKey(NODE_ID, clientId), connectionId, Date.now().toString());
}
logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Connection ID: ${connectionId}, Total connections: ${existingClients.length}`);
};
const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
const existingClients = connectedClients.get(newtId) || [];
const removeClient = async (clientType: ClientType, clientId: string, ws: AuthenticatedWebSocket): Promise<void> => {
const mapKey = getClientMapKey(clientId);
const existingClients = connectedClients.get(mapKey) || [];
const updatedClients = existingClients.filter(client => client !== ws);
if (updatedClients.length === 0) {
connectedClients.delete(newtId);
logger.info(`All connections removed for Newt ID: ${newtId}`);
connectedClients.delete(mapKey);
if (redisManager.isRedisEnabled()) {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
}
logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`);
} else {
connectedClients.set(newtId, updatedClients);
logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`);
connectedClients.set(mapKey, updatedClients);
if (redisManager.isRedisEnabled() && ws.connectionId) {
await redisManager.hdel(getNodeConnectionsKey(NODE_ID, clientId), ws.connectionId);
}
logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`);
}
};
// Helper functions for sending messages
const sendToClient = (newtId: string, message: WSMessage): boolean => {
const clients = connectedClients.get(newtId);
// Local message sending (within this node)
const sendToClientLocal = async (clientId: string, message: WSMessage): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
logger.info(`No active connections found for Newt ID: ${newtId}`);
return false;
}
const messageString = JSON.stringify(message);
clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) {
@@ -91,9 +165,10 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => {
return true;
};
const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => {
connectedClients.forEach((clients, newtId) => {
if (newtId !== excludeNewtId) {
const broadcastToAllExceptLocal = async (message: WSMessage, excludeClientId?: string): Promise<void> => {
connectedClients.forEach((clients, mapKey) => {
const [type, id] = mapKey.split(":");
if (!(excludeClientId && id === excludeClientId)) {
clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message));
@@ -103,84 +178,152 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void
});
};
// Token verification middleware (unchanged)
const verifyToken = async (token: string): Promise<TokenPayload | null> => {
try {
const { session, newt } = await validateNewtSessionToken(token);
// Cross-node message sending (via Redis)
const sendToClient = async (clientId: string, message: WSMessage): Promise<boolean> => {
// Try to send locally first
const localSent = await sendToClientLocal(clientId, message);
if (!session || !newt) {
return null;
// If Redis is enabled, also send via Redis pub/sub to other nodes
if (redisManager.isRedisEnabled()) {
const redisMessage: RedisMessage = {
type: 'direct',
targetClientId: clientId,
message,
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
}
return localSent;
};
const broadcastToAllExcept = async (message: WSMessage, excludeClientId?: string): Promise<void> => {
// Broadcast locally
await broadcastToAllExceptLocal(message, excludeClientId);
// If Redis is enabled, also broadcast via Redis pub/sub to other nodes
if (redisManager.isRedisEnabled()) {
const redisMessage: RedisMessage = {
type: 'broadcast',
excludeClientId,
message,
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
}
};
// Check if a client has active connections across all nodes
const hasActiveConnections = async (clientId: string): Promise<boolean> => {
if (!redisManager.isRedisEnabled()) {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
return !!(clients && clients.length > 0);
}
const activeNodes = await redisManager.smembers(getConnectionsKey(clientId));
return activeNodes.length > 0;
};
// Get all active nodes for a client
const getActiveNodes = async (clientType: ClientType, clientId: string): Promise<string[]> => {
if (!redisManager.isRedisEnabled()) {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
return (clients && clients.length > 0) ? [NODE_ID] : [];
}
return await redisManager.smembers(getConnectionsKey(clientId));
};
// Token verification middleware
const verifyToken = async (token: string, clientType: ClientType): Promise<TokenPayload | null> => {
try {
if (clientType === 'newt') {
const { session, newt } = await validateNewtSessionToken(token);
if (!session || !newt) {
return null;
}
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { client: existingNewt[0], session, clientType };
} else {
const { session, olm } = await validateOlmSessionToken(token);
if (!session || !olm) {
return null;
}
const existingOlm = await db
.select()
.from(olms)
.where(eq(olms.olmId, olm.olmId));
if (!existingOlm || !existingOlm[0]) {
return null;
}
return { client: existingOlm[0], session, clientType };
}
const existingNewt = await db
.select()
.from(newts)
.where(eq(newts.newtId, newt.newtId));
if (!existingNewt || !existingNewt[0]) {
return null;
}
return { newt: existingNewt[0], session };
} catch (error) {
logger.error("Token verification failed:", error);
return null;
}
};
const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
const setupConnection = async (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): Promise<void> => {
logger.info("Establishing websocket connection");
if (!newt) {
logger.error("Connection attempt without newt");
if (!client) {
logger.error("Connection attempt without client");
return ws.terminate();
}
ws.newt = newt;
ws.client = client;
ws.clientType = clientType;
// Add client to tracking
addClient(newt.newtId, ws);
const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId;
await addClient(clientType, clientId, ws);
ws.on("message", async (data) => {
try {
const message: WSMessage = JSON.parse(data.toString());
// logger.info(`Message received from Newt ID ${newtId}:`, message);
// Validate message format
if (!message.type || typeof message.type !== "string") {
throw new Error("Invalid message format: missing or invalid type");
}
// Get the appropriate handler for the message type
const handler = messageHandlers[message.type];
if (!handler) {
throw new Error(`Unsupported message type: ${message.type}`);
}
// Process the message and get response
const response = await handler({
message,
senderWs: ws,
newt: ws.newt,
client: ws.client,
clientType: ws.clientType!,
sendToClient,
broadcastToAllExcept,
connectedClients
});
// Send response if one was returned
if (response) {
if (response.broadcast) {
// Broadcast to all clients except sender if specified
broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined);
} else if (response.targetNewtId) {
// Send to specific client if targetNewtId is provided
sendToClient(response.targetNewtId, response.message);
await broadcastToAllExcept(
response.message,
response.excludeSender ? clientId : undefined
);
} else if (response.targetClientId) {
await sendToClient(response.targetClientId, response.message);
} else {
// Send back to sender
ws.send(JSON.stringify(response.message));
}
}
} catch (error) {
logger.error("Message handling error:", error);
ws.send(JSON.stringify({
@@ -194,18 +337,18 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
});
ws.on("close", () => {
removeClient(newt.newtId, ws);
logger.info(`Client disconnected - Newt ID: ${newt.newtId}`);
removeClient(clientType, clientId, ws);
logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`);
});
ws.on("error", (error: Error) => {
logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error);
logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error);
});
logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`);
logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`);
};
// Router endpoint (unchanged)
// Router endpoint
router.get("/ws", (req: Request, res: Response) => {
res.status(200).send("WebSocket endpoint");
});
@@ -214,18 +357,22 @@ router.get("/ws", (req: Request, res: Response) => {
const handleWSUpgrade = (server: HttpServer): void => {
server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
try {
const token = request.url?.includes("?")
? new URLSearchParams(request.url.split("?")[1]).get("token") || ""
: request.headers["sec-websocket-protocol"];
const url = new URL(request.url || '', `http://${request.headers.host}`);
const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || '';
let clientType = url.searchParams.get('clientType') as ClientType;
if (!token) {
logger.warn("Unauthorized connection attempt: no token...");
if (!clientType) {
clientType = "newt";
}
if (!token || !clientType || !['newt', 'olm'].includes(clientType)) {
logger.warn("Unauthorized connection attempt: invalid token or client type...");
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
socket.destroy();
return;
}
const tokenPayload = await verifyToken(token);
const tokenPayload = await verifyToken(token, clientType);
if (!tokenPayload) {
logger.warn("Unauthorized connection attempt: invalid token...");
socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
@@ -234,7 +381,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
}
wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
setupConnection(ws, tokenPayload.newt);
setupConnection(ws, tokenPayload.client, tokenPayload.clientType);
});
} catch (error) {
logger.error("WebSocket upgrade error:", error);
@@ -244,10 +391,54 @@ const handleWSUpgrade = (server: HttpServer): void => {
});
};
// Initialize Redis subscription when the module is loaded
if (redisManager.isRedisEnabled()) {
initializeRedisSubscription().catch(error => {
logger.error('Failed to initialize Redis subscription:', error);
});
logger.info(`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`);
} else {
logger.debug('WebSocket handler initialized in local mode (Redis disabled)');
}
// Cleanup function for graceful shutdown
const cleanup = async (): Promise<void> => {
try {
// Close all WebSocket connections
connectedClients.forEach((clients) => {
clients.forEach(client => {
if (client.readyState === WebSocket.OPEN) {
client.terminate();
}
});
});
// Clean up Redis tracking for this node
if (redisManager.isRedisEnabled()) {
const keys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
if (keys.length > 0) {
await Promise.all(keys.map(key => redisManager.del(key)));
}
}
logger.info('WebSocket cleanup completed');
} catch (error) {
logger.error('Error during WebSocket cleanup:', error);
}
};
// Handle process termination
process.on('SIGTERM', cleanup);
process.on('SIGINT', cleanup);
export {
router,
handleWSUpgrade,
sendToClient,
broadcastToAllExcept,
connectedClients
connectedClients,
hasActiveConnections,
getActiveNodes,
NODE_ID,
cleanup
};