Merge branch 'hp-multi-client' into auth-providers-clients

This commit is contained in:
miloschwartz
2025-04-20 16:15:40 -04:00
70 changed files with 27368 additions and 159 deletions

View File

@@ -24,7 +24,7 @@ export const newtGetTokenBodySchema = z.object({
export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
export async function getToken(
export async function getNewtToken(
req: Request,
res: Response,
next: NextFunction

View File

@@ -0,0 +1,161 @@
import { z } from "zod";
import { MessageHandler } from "../ws";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import db from "@server/db";
import { clients, clientSites, Newt, sites } from "@server/db/schema";
import { eq } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
const inputSchema = z.object({
publicKey: z.string(),
port: z.number().int().positive()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
const now = new Date().getTime() / 1000;
logger.debug("Handling Newt get config message!");
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, port } = message.data as Input;
const siteId = newt.siteId;
// Get the current site data
const [existingSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId));
if (!existingSite) {
logger.warn("handleGetConfigMessage: Site not found");
return;
}
// todo check if the public key has changed
// we need to wait for hole punch success
if (!existingSite.endpoint) {
logger.warn(`Site ${existingSite.siteId} has no endpoint, skipping`);
return;
}
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 6) {
logger.warn(
`Site ${existingSite.siteId} last hole punch is too old, skipping`
);
return;
}
// update the endpoint and the public key
const [site] = await db
.update(sites)
.set({
publicKey,
listenPort: port
})
.where(eq(sites.siteId, siteId))
.returning();
if (!site) {
logger.error("handleGetConfigMessage: Failed to update site");
return;
}
// Get all clients connected to this site
const clientsRes = await db
.select()
.from(clients)
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
// Prepare peers data for the response
const peers = await Promise.all(
clientsRes
.filter((client) => {
if (!client.clients.pubKey) {
return false;
}
if (!client.clients.subnet) {
return false;
}
if (!client.clients.endpoint) {
return false;
}
if (!client.clients.online) {
return false;
}
return true;
})
.map(async (client) => {
// Add or update this peer on the olm if it is connected
try {
if (site.endpoint && site.publicKey) {
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort
});
}
} catch (error) {
logger.error(
`Failed to add/update peer ${client.clients.pubKey} to newt ${newt.newtId}: ${error}`
);
}
return {
publicKey: client.clients.pubKey!,
allowedIps: [client.clients.subnet!],
endpoint: client.clientSites.isRelayed
? ""
: client.clients.endpoint! // if its relayed it should be localhost
};
})
);
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Build the configuration response
const configResponse = {
ipAddress: site.address,
peers: validPeers
};
logger.debug("Sending config: ", configResponse);
return {
message: {
type: "newt/wg/receive-config",
data: {
...configResponse
}
},
broadcast: false,
excludeSender: false
};
};

View File

@@ -2,6 +2,7 @@ import db from "@server/db";
import { MessageHandler } from "../ws";
import {
exitNodes,
Newt,
resources,
sites,
Target,
@@ -11,10 +12,11 @@ import { eq, and, sql, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
export const handleRegisterMessage: MessageHandler = async (context) => {
const { message, newt, sendToClient } = context;
export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
logger.info("Handling register message!");
logger.info("Handling register newt message!");
if (!newt) {
logger.warn("Newt not found");

View File

@@ -0,0 +1,52 @@
import db from "@server/db";
import { MessageHandler } from "../ws";
import { clients, Newt } from "@server/db/schema";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
bytesOut: number;
}
export const handleReceiveBandwidthMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
if (!message.data.bandwidthData) {
logger.warn("No bandwidth data provided");
}
const bandwidthData: PeerBandwidth[] = message.data.bandwidthData;
if (!Array.isArray(bandwidthData)) {
throw new Error("Invalid bandwidth data");
}
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Find the client by public key
const [client] = await trx
.select()
.from(clients)
.where(eq(clients.pubKey, publicKey))
.limit(1);
if (!client) {
continue;
}
// Update the client's bandwidth usage
await trx
.update(clients)
.set({
megabytesOut: (client.megabytesIn || 0) + bytesIn,
megabytesIn: (client.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString(),
})
.where(eq(clients.clientId, client.clientId));
}
});
};

View File

@@ -1,3 +1,5 @@
export * from "./createNewt";
export * from "./getToken";
export * from "./handleRegisterMessage";
export * from "./getNewtToken";
export * from "./handleNewtRegisterMessage";
export * from "./handleReceiveBandwidthMessage";
export * from "./handleGetConfigMessage";

View File

@@ -0,0 +1,78 @@
import db from '@server/db';
import { newts, sites } from '@server/db/schema';
import { eq } from 'drizzle-orm';
import { sendToClient } from '../ws';
import logger from '@server/logger';
export async function addPeer(siteId: number, peer: {
publicKey: string;
allowedIps: string[];
endpoint: string;
}) {
const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1);
if (!newt) {
throw new Error(`Site found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: 'newt/wg/peer/add',
data: peer
});
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`);
}
export async function deletePeer(siteId: number, publicKey: string) {
const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: 'newt/wg/peer/remove',
data: {
publicKey
}
});
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`);
}
export async function updatePeer(siteId: number, publicKey: string, peer: {
allowedIps?: string[];
endpoint?: string;
}) {
const [site] = await db.select().from(sites).where(eq(sites.siteId, siteId)).limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db.select().from(newts).where(eq(newts.siteId, siteId)).limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: 'newt/wg/peer/update',
data: {
publicKey,
...peer
}
});
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`);
}