Add version and send it down

This commit is contained in:
Owen
2025-12-19 16:44:57 -05:00
parent 9bd66fa306
commit 322f3bfb1d
4 changed files with 194 additions and 29 deletions

View File

@@ -573,6 +573,20 @@ class RedisManager {
} }
} }
public async incr(key: string): Promise<number> {
if (!this.isRedisEnabled() || !this.writeClient) return 0;
try {
return await this.executeWithRetry(
() => this.writeClient!.incr(key),
"Redis INCR"
);
} catch (error) {
logger.error("Redis INCR error:", error);
return 0;
}
}
public async sadd(key: string, member: string): Promise<boolean> { public async sadd(key: string, member: string): Promise<boolean> {
if (!this.isRedisEnabled() || !this.writeClient) return false; if (!this.isRedisEnabled() || !this.writeClient) return false;

View File

@@ -43,7 +43,8 @@ import {
WSMessage, WSMessage,
TokenPayload, TokenPayload,
WebSocketRequest, WebSocketRequest,
RedisMessage RedisMessage,
SendMessageOptions
} from "@server/routers/ws"; } from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -172,6 +173,9 @@ const REDIS_CHANNEL = "websocket_messages";
// Client tracking map (local to this node) // Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map(); const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (local to this node, resets on server restart)
const clientConfigVersions: Map<string, number> = new Map();
// Recovery tracking // Recovery tracking
let isRedisRecoveryInProgress = false; let isRedisRecoveryInProgress = false;
@@ -182,6 +186,7 @@ const getClientMapKey = (clientId: string) => clientId;
const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`; const getConnectionsKey = (clientId: string) => `ws:connections:${clientId}`;
const getNodeConnectionsKey = (nodeId: string, clientId: string) => const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
`ws:node:${nodeId}:${clientId}`; `ws:node:${nodeId}:${clientId}`;
const getConfigVersionKey = (clientId: string) => `ws:configVersion:${clientId}`;
// Initialize Redis subscription for cross-node messaging // Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => { const initializeRedisSubscription = async (): Promise<void> => {
@@ -377,17 +382,76 @@ const removeClient = async (
} }
}; };
// Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number> => {
// Try Redis first if available
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
if (redisVersion !== null) {
const version = parseInt(redisVersion, 10);
// Sync local cache with Redis
clientConfigVersions.set(clientId, version);
return version;
}
} catch (error) {
logger.error("Failed to get config version from Redis:", error);
}
}
// Fall back to local cache
return clientConfigVersions.get(clientId) || 0;
};
// Helper to increment and get the new config version for a client
const incrementClientConfigVersion = async (clientId: string): Promise<number> => {
let newVersion: number;
if (redisManager.isRedisEnabled()) {
try {
// Use Redis INCR for atomic increment across nodes
newVersion = await redisManager.incr(getConfigVersionKey(clientId));
// Sync local cache
clientConfigVersions.set(clientId, newVersion);
return newVersion;
} catch (error) {
logger.error("Failed to increment config version in Redis:", error);
// Fall through to local increment
}
}
// Local increment
const currentVersion = clientConfigVersions.get(clientId) || 0;
newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
return newVersion;
};
// Local message sending (within this node) // Local message sending (within this node)
const sendToClientLocal = async ( const sendToClientLocal = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
const mapKey = getClientMapKey(clientId); const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey); const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) { if (!clients || clients.length === 0) {
return false; return false;
} }
const messageString = JSON.stringify(message);
// Handle config version
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
// Add config version to message
const messageWithVersion = {
...message,
configVersion
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(messageString);
@@ -395,43 +459,69 @@ const sendToClientLocal = async (
}); });
logger.debug( logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}` `sendToClient: Message type ${message.type} sent to clientId ${clientId} (configVersion: ${configVersion})`
); );
return true; return true;
}; };
const broadcastToAllExceptLocal = async ( const broadcastToAllExceptLocal = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
connectedClients.forEach((clients, mapKey) => { for (const [mapKey, clients] of connectedClients.entries()) {
const [type, id] = mapKey.split(":"); const [type, id] = mapKey.split(":");
if (!(excludeClientId && id === excludeClientId)) { const clientId = mapKey; // mapKey is the clientId
if (!(excludeClientId && clientId === excludeClientId)) {
// Handle config version per client
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
// Add config version to message
const messageWithVersion = {
...message,
configVersion
};
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message)); client.send(JSON.stringify(messageWithVersion));
} }
}); });
} }
}); }
}; };
// Cross-node message sending (via Redis) // Cross-node message sending (via Redis)
const sendToClient = async ( const sendToClient = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal(clientId, message); const localSent = await sendToClientLocal(clientId, message, options);
// Only send via Redis if the client is not connected locally and Redis is enabled // Only send via Redis if the client is not connected locally and Redis is enabled
if (!localSent && redisManager.isRedisEnabled()) { if (!localSent && redisManager.isRedisEnabled()) {
try { try {
// If we need to increment config version, do it before sending via Redis
// so remote nodes send the correct version
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
const redisMessage: RedisMessage = { const redisMessage: RedisMessage = {
type: "direct", type: "direct",
targetClientId: clientId, targetClientId: clientId,
message, message: {
...message,
configVersion
},
fromNodeId: NODE_ID fromNodeId: NODE_ID
}; };
@@ -458,19 +548,22 @@ const sendToClient = async (
const broadcastToAllExcept = async ( const broadcastToAllExcept = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
// Broadcast locally // Broadcast locally
await broadcastToAllExceptLocal(message, excludeClientId); await broadcastToAllExceptLocal(message, excludeClientId, options);
// If Redis is enabled, also broadcast via Redis pub/sub to other nodes // If Redis is enabled, also broadcast via Redis pub/sub to other nodes
// Note: For broadcasts, we include the options so remote nodes can handle versioning
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
const redisMessage: RedisMessage = { const redisMessage: RedisMessage = {
type: "broadcast", type: "broadcast",
excludeClientId, excludeClientId,
message, message,
fromNodeId: NODE_ID fromNodeId: NODE_ID,
options
}; };
await redisManager.publish( await redisManager.publish(
@@ -936,5 +1029,6 @@ export {
getActiveNodes, getActiveNodes,
disconnectClient, disconnectClient,
NODE_ID, NODE_ID,
cleanup cleanup,
getClientConfigVersion
}; };

View File

@@ -25,6 +25,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
connectionId?: string; connectionId?: string;
isFullyConnected?: boolean; isFullyConnected?: boolean;
pendingMessages?: Buffer[]; pendingMessages?: Buffer[];
configVersion?: number;
} }
export interface TokenPayload { export interface TokenPayload {
@@ -36,6 +37,7 @@ export interface TokenPayload {
export interface WSMessage { export interface WSMessage {
type: string; type: string;
data: any; data: any;
configVersion?: number;
} }
export interface HandlerResponse { export interface HandlerResponse {
@@ -50,10 +52,11 @@ export interface HandlerContext {
senderWs: WebSocket; senderWs: WebSocket;
client: Newt | Olm | RemoteExitNode | undefined; client: Newt | Olm | RemoteExitNode | undefined;
clientType: ClientType; clientType: ClientType;
sendToClient: (clientId: string, message: WSMessage) => Promise<boolean>; sendToClient: (clientId: string, message: WSMessage, options?: SendMessageOptions) => Promise<boolean>;
broadcastToAllExcept: ( broadcastToAllExcept: (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options?: SendMessageOptions
) => Promise<void>; ) => Promise<void>;
connectedClients: Map<string, WebSocket[]>; connectedClients: Map<string, WebSocket[]>;
} }
@@ -62,6 +65,11 @@ export type MessageHandler = (
context: HandlerContext context: HandlerContext
) => Promise<HandlerResponse | void>; ) => Promise<HandlerResponse | void>;
// Options for sending messages with config version tracking
export interface SendMessageOptions {
incrementConfigVersion?: boolean;
}
// Redis message type for cross-node communication // Redis message type for cross-node communication
export interface RedisMessage { export interface RedisMessage {
type: "direct" | "broadcast"; type: "direct" | "broadcast";
@@ -69,4 +77,5 @@ export interface RedisMessage {
excludeClientId?: string; excludeClientId?: string;
message: WSMessage; message: WSMessage;
fromNodeId: string; fromNodeId: string;
options?: SendMessageOptions;
} }

View File

@@ -15,7 +15,8 @@ import {
TokenPayload, TokenPayload,
WebSocketRequest, WebSocketRequest,
WSMessage, WSMessage,
AuthenticatedWebSocket AuthenticatedWebSocket,
SendMessageOptions
} from "./types"; } from "./types";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -34,6 +35,8 @@ const NODE_ID = uuidv4();
// Client tracking map (local to this node) // Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map(); const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (clientId -> version)
const clientConfigVersions: Map<string, number> = new Map();
// Helper to get map key // Helper to get map key
const getClientMapKey = (clientId: string) => clientId; const getClientMapKey = (clientId: string) => clientId;
@@ -84,14 +87,34 @@ const removeClient = async (
// Local message sending (within this node) // Local message sending (within this node)
const sendToClientLocal = async ( const sendToClientLocal = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
const mapKey = getClientMapKey(clientId); const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey); const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) { if (!clients || clients.length === 0) {
return false; return false;
} }
const messageString = JSON.stringify(message);
// Increment config version if requested
if (options.incrementConfigVersion) {
const currentVersion = clientConfigVersions.get(clientId) || 0;
const newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
// Update version on all client connections
clients.forEach((client) => {
client.configVersion = newVersion;
});
}
// Include config version in message
const configVersion = clientConfigVersions.get(clientId) || 0;
const messageWithVersion = {
...message,
configVersion
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(messageString);
@@ -102,14 +125,31 @@ const sendToClientLocal = async (
const broadcastToAllExceptLocal = async ( const broadcastToAllExceptLocal = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
connectedClients.forEach((clients, mapKey) => { connectedClients.forEach((clients, mapKey) => {
const [type, id] = mapKey.split(":"); const [type, id] = mapKey.split(":");
if (!(excludeClientId && id === excludeClientId)) { const clientId = mapKey; // mapKey is the clientId
if (!(excludeClientId && clientId === excludeClientId)) {
// Handle config version per client
if (options.incrementConfigVersion) {
const currentVersion = clientConfigVersions.get(clientId) || 0;
const newVersion = currentVersion + 1;
clientConfigVersions.set(clientId, newVersion);
clients.forEach((client) => {
client.configVersion = newVersion;
});
}
// Include config version in message for this client
const configVersion = clientConfigVersions.get(clientId) || 0;
const messageWithVersion = {
...message,
configVersion
};
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message)); client.send(JSON.stringify(messageWithVersion));
} }
}); });
} }
@@ -119,10 +159,11 @@ const broadcastToAllExceptLocal = async (
// Cross-node message sending // Cross-node message sending
const sendToClient = async ( const sendToClient = async (
clientId: string, clientId: string,
message: WSMessage message: WSMessage,
options: SendMessageOptions = {}
): Promise<boolean> => { ): Promise<boolean> => {
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal(clientId, message); const localSent = await sendToClientLocal(clientId, message, options);
logger.debug( logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}` `sendToClient: Message type ${message.type} sent to clientId ${clientId}`
@@ -133,10 +174,11 @@ const sendToClient = async (
const broadcastToAllExcept = async ( const broadcastToAllExcept = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string excludeClientId?: string,
options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
// Broadcast locally // Broadcast locally
await broadcastToAllExceptLocal(message, excludeClientId); await broadcastToAllExceptLocal(message, excludeClientId, options);
}; };
// Check if a client has active connections across all nodes // Check if a client has active connections across all nodes
@@ -146,6 +188,11 @@ const hasActiveConnections = async (clientId: string): Promise<boolean> => {
return !!(clients && clients.length > 0); return !!(clients && clients.length > 0);
}; };
// Get the current config version for a client
const getClientConfigVersion = (clientId: string): number => {
return clientConfigVersions.get(clientId) || 0;
};
// Get all active nodes for a client // Get all active nodes for a client
const getActiveNodes = async ( const getActiveNodes = async (
clientType: ClientType, clientType: ClientType,
@@ -434,5 +481,6 @@ export {
getActiveNodes, getActiveNodes,
disconnectClient, disconnectClient,
NODE_ID, NODE_ID,
cleanup cleanup,
getClientConfigVersion
}; };