Add increment options and slight cleanup

This commit is contained in:
Owen
2026-01-12 20:48:18 -08:00
parent 0ccd5714f9
commit eba25fcc4d
14 changed files with 92 additions and 53 deletions

View File

@@ -50,10 +50,14 @@ export async function sendToExitNode(
); );
} }
return sendToClient(remoteExitNode.remoteExitNodeId, { return sendToClient(
type: request.remoteType, remoteExitNode.remoteExitNodeId,
data: request.data {
}); type: request.remoteType,
data: request.data
},
{ incrementConfigVersion: true }
);
} else { } else {
let hostname = exitNode.reachableAt; let hostname = exitNode.reachableAt;

View File

@@ -119,12 +119,21 @@ const processMessage = async (
if (response.broadcast) { if (response.broadcast) {
await broadcastToAllExcept( await broadcastToAllExcept(
response.message, response.message,
response.excludeSender ? clientId : undefined response.excludeSender ? clientId : undefined,
response.options
); );
} else if (response.targetClientId) { } else if (response.targetClientId) {
await sendToClient(response.targetClientId, response.message); await sendToClient(
response.targetClientId,
response.message,
response.options
);
} else { } else {
ws.send(JSON.stringify(response.message)); await sendToClient(
clientId,
response.message,
response.options
);
} }
} }
} catch (error) { } catch (error) {
@@ -186,7 +195,8 @@ const getClientMapKey = (clientId: string) => clientId;
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`; const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
const getNodeConnectionsKey = (nodeId: string, clientId: string) => const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
`ws:node:${nodeId}:${clientId}`; `ws:node:${nodeId}:${clientId}`;
const getConfigVersionKey = (clientId: string) => `ws:configVersion:${clientId}`; const getConfigVersionKey = (clientId: string) =>
`ws:configVersion:${clientId}`;
// Initialize Redis subscription for cross-node messaging // Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => { const initializeRedisSubscription = async (): Promise<void> => {
@@ -387,7 +397,9 @@ const getClientConfigVersion = async (clientId: string): Promise<number> => {
// Try Redis first if available // Try Redis first if available
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId)); const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) { if (redisVersion !== null) {
const version = parseInt(redisVersion, 10); const version = parseInt(redisVersion, 10);
// Sync local cache with Redis // Sync local cache with Redis
@@ -398,15 +410,17 @@ const getClientConfigVersion = async (clientId: string): Promise<number> => {
logger.error("Failed to get config version from Redis:", error); logger.error("Failed to get config version from Redis:", error);
} }
} }
// Fall back to local cache // Fall back to local cache
return clientConfigVersions.get(clientId) || 0; return clientConfigVersions.get(clientId) || 0;
}; };
// Helper to increment and get the new config version for a client // Helper to increment and get the new config version for a client
const incrementClientConfigVersion = async (clientId: string): Promise<number> => { const incrementClientConfigVersion = async (
clientId: string
): Promise<number> => {
let newVersion: number; let newVersion: number;
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
// Use Redis INCR for atomic increment across nodes // Use Redis INCR for atomic increment across nodes
@@ -419,7 +433,7 @@ const incrementClientConfigVersion = async (clientId: string): Promise<number> =
// Fall through to local increment // Fall through to local increment
} }
} }
// Local increment // Local increment
const currentVersion = clientConfigVersions.get(clientId) || 0; const currentVersion = clientConfigVersions.get(clientId) || 0;
newVersion = currentVersion + 1; newVersion = currentVersion + 1;
@@ -438,19 +452,19 @@ const sendToClientLocal = async (
if (!clients || clients.length === 0) { if (!clients || clients.length === 0) {
return false; return false;
} }
// Handle config version // Handle config version
let configVersion = await getClientConfigVersion(clientId); let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) { if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId); configVersion = await incrementClientConfigVersion(clientId);
} }
// Add config version to message // Add config version to message
const messageWithVersion = { const messageWithVersion = {
...message, ...message,
configVersion configVersion
}; };
const messageString = JSON.stringify(messageWithVersion); const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
@@ -462,7 +476,6 @@ const sendToClientLocal = async (
`sendToClient: Message type ${message.type} sent to clientId ${clientId} (configVersion: ${configVersion})` `sendToClient: Message type ${message.type} sent to clientId ${clientId} (configVersion: ${configVersion})`
); );
return true; return true;
}; };
@@ -480,13 +493,13 @@ const broadcastToAllExceptLocal = async (
if (options.incrementConfigVersion) { if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId); configVersion = await incrementClientConfigVersion(clientId);
} }
// Add config version to message // Add config version to message
const messageWithVersion = { const messageWithVersion = {
...message, ...message,
configVersion configVersion
}; };
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion)); client.send(JSON.stringify(messageWithVersion));
@@ -514,7 +527,7 @@ const sendToClient = async (
if (options.incrementConfigVersion) { if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId); configVersion = await incrementClientConfigVersion(clientId);
} }
const redisMessage: RedisMessage = { const redisMessage: RedisMessage = {
type: "direct", type: "direct",
targetClientId: clientId, targetClientId: clientId,

View File

@@ -28,7 +28,7 @@ export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
await sendToClient(newtId, { await sendToClient(newtId, {
type: `newt/wg/targets/add`, type: `newt/wg/targets/add`,
data: batches[i] data: batches[i]
}); }, { incrementConfigVersion: true });
} }
} }
@@ -44,7 +44,7 @@ export async function removeTargets(
await sendToClient(newtId, { await sendToClient(newtId, {
type: `newt/wg/targets/remove`, type: `newt/wg/targets/remove`,
data: batches[i] data: batches[i]
}); },{ incrementConfigVersion: true });
} }
} }
@@ -69,7 +69,7 @@ export async function updateTargets(
oldTargets: oldBatches[i] || [], oldTargets: oldBatches[i] || [],
newTargets: newBatches[i] || [] newTargets: newBatches[i] || []
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -101,7 +101,7 @@ export async function addPeerData(
remoteSubnets: remoteSubnets, remoteSubnets: remoteSubnets,
aliases: aliases aliases: aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -132,7 +132,7 @@ export async function removePeerData(
remoteSubnets: remoteSubnets, remoteSubnets: remoteSubnets,
aliases: aliases aliases: aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -173,7 +173,7 @@ export async function updatePeerData(
...remoteSubnets, ...remoteSubnets,
...aliases ...aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }

View File

@@ -153,6 +153,6 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
} }
}, },
broadcast: false, broadcast: false,
excludeSender: false, excludeSender: false
}; };
}; };

View File

@@ -110,7 +110,9 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
const configVersion = await getClientConfigVersion(newt.newtId); const configVersion = await getClientConfigVersion(newt.newtId);
if (message.configVersion && configVersion != message.configVersion) { if (message.configVersion && configVersion != message.configVersion) {
logger.warn(`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`); logger.warn(
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
);
// TODO: sync the client // TODO: sync the client
} }

View File

@@ -39,7 +39,7 @@ export async function addPeer(
await sendToClient(newtId, { await sendToClient(newtId, {
type: "newt/wg/peer/add", type: "newt/wg/peer/add",
data: peer data: peer
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -81,7 +81,7 @@ export async function deletePeer(
data: { data: {
publicKey publicKey
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -128,7 +128,7 @@ export async function updatePeer(
publicKey, publicKey,
...peer ...peer
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });

View File

@@ -5,7 +5,11 @@ import logger from "@server/logger";
export async function sendOlmSyncMessage(olm: Olm, client: Client) { export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT // NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
const siteConfigurations = await buildSiteConfigurationForOlmClient(client, client.pubKey, false); const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
client.pubKey,
false
);
await sendToClient(olm.olmId, { await sendToClient(olm.olmId, {
type: "olm/sync", type: "olm/sync",

View File

@@ -22,7 +22,7 @@ export async function addTargets(
data: { data: {
targets: payloadTargets targets: payloadTargets
} }
}); }, { incrementConfigVersion: true });
// Create a map for quick lookup // Create a map for quick lookup
const healthCheckMap = new Map<number, TargetHealthCheck>(); const healthCheckMap = new Map<number, TargetHealthCheck>();
@@ -103,7 +103,7 @@ export async function addTargets(
data: { data: {
targets: validHealthCheckTargets targets: validHealthCheckTargets
} }
}); }, { incrementConfigVersion: true });
} }
export async function removeTargets( export async function removeTargets(
@@ -124,7 +124,7 @@ export async function removeTargets(
data: { data: {
targets: payloadTargets targets: payloadTargets
} }
}); }, { incrementConfigVersion: true });
const healthCheckTargets = targets.map((target) => { const healthCheckTargets = targets.map((target) => {
return target.targetId; return target.targetId;
@@ -135,5 +135,5 @@ export async function removeTargets(
data: { data: {
ids: healthCheckTargets ids: healthCheckTargets
} }
}); }, { incrementConfigVersion: true });
} }

View File

@@ -170,18 +170,18 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
lastPing: Math.floor(Date.now() / 1000), lastPing: Math.floor(Date.now() / 1000),
online: true online: true
}) })
.where(eq(clients.clientId, olm.clientId)).returning(); .where(eq(clients.clientId, olm.clientId))
.returning();
// get the version // get the version
const configVersion = await getClientConfigVersion(olm.olmId); const configVersion = await getClientConfigVersion(olm.olmId);
if (message.configVersion && configVersion != message.configVersion) { if (message.configVersion && configVersion != message.configVersion) {
logger.warn(`Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`); logger.warn(
`Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
);
await sendOlmSyncMessage(olm, client); await sendOlmSyncMessage(olm, client);
} }
} catch (error) { } catch (error) {
logger.error("Error handling ping message", { error }); logger.error("Error handling ping message", { error });
} }

View File

@@ -157,12 +157,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array // Extract the count value from the result array
const sitesCount = sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations // Prepare an array to store site configurations
logger.debug( logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
`Found ${sitesCount} sites for client ${client.clientId}`
);
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking // the olm will pump the register so we can keep checking
@@ -175,7 +174,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
// NOTE: its important that the client here is the old client and the public key is the new key // NOTE: its important that the client here is the old client and the public key is the new key
const siteConfigurations = await buildSiteConfigurationForOlmClient(client, publicKey, relay); const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
publicKey,
relay
);
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES // REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
// if (siteConfigurations.length === 0) { // if (siteConfigurations.length === 0) {

View File

@@ -45,7 +45,7 @@ export async function addPeer(
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
aliases: peer.aliases aliases: peer.aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -76,7 +76,7 @@ export async function deletePeer(
publicKey, publicKey,
siteId: siteId siteId: siteId
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -121,7 +121,7 @@ export async function updatePeer(
remoteSubnets: peer.remoteSubnets, remoteSubnets: peer.remoteSubnets,
aliases: peer.aliases aliases: peer.aliases
} }
}).catch((error) => { }, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -161,6 +161,8 @@ export async function initPeerAddHandshake(
endpoint: peer.exitNode.endpoint endpoint: peer.exitNode.endpoint
} }
} }
// }, { incrementConfigVersion: true }).catch((error) => {
// TODO: DOES THIS NEED TO BE A INCREMENT VERSION? I AM NOT SURE BECAUSE IT WOULD BE TRIGGERED BY THE SYNC?
}).catch((error) => { }).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });

View File

@@ -5,7 +5,11 @@ import logger from "@server/logger";
export async function sendOlmSyncMessage(olm: Olm, client: Client) { export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT // NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
const siteConfigurations = await buildSiteConfigurationForOlmClient(client, client.pubKey, false); const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
client.pubKey,
false
);
await sendToClient(olm.olmId, { await sendToClient(olm.olmId, {
type: "olm/sync", type: "olm/sync",

View File

@@ -45,6 +45,7 @@ export interface HandlerResponse {
broadcast?: boolean; broadcast?: boolean;
excludeSender?: boolean; excludeSender?: boolean;
targetClientId?: string; targetClientId?: string;
options?: SendMessageOptions;
} }
export interface HandlerContext { export interface HandlerContext {

View File

@@ -306,15 +306,21 @@ const setupConnection = async (
if (response.broadcast) { if (response.broadcast) {
await broadcastToAllExcept( await broadcastToAllExcept(
response.message, response.message,
response.excludeSender ? clientId : undefined response.excludeSender ? clientId : undefined,
response.options
); );
} else if (response.targetClientId) { } else if (response.targetClientId) {
await sendToClient( await sendToClient(
response.targetClientId, response.targetClientId,
response.message response.message,
response.options
); );
} else { } else {
ws.send(JSON.stringify(response.message)); await sendToClient(
clientId,
response.message,
response.options
);
} }
} }
} catch (error) { } catch (error) {