Properly handle blocked devices

This commit is contained in:
Owen
2026-01-12 21:14:18 -08:00
parent 673cd0fcd1
commit 552adf3200
2 changed files with 71 additions and 58 deletions

View File

@@ -108,7 +108,29 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
return; return;
} }
let client: (typeof clients.$inferSelect) | undefined; if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
try {
// get the client
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
.limit(1);
if (!client) {
logger.warn("Client not found for olm ping");
return;
}
if (client.blocked) {
// NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them
logger.debug(`Blocked client ${client.clientId} attempted olm ping`);
return;
}
if (olm.userId) { if (olm.userId) {
// we need to check a user token to make sure its still valid // we need to check a user token to make sure its still valid
@@ -122,26 +144,11 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
logger.warn("User ID mismatch for olm ping"); logger.warn("User ID mismatch for olm ping");
return; return;
} }
if (user.userId !== client.userId) {
// get the client logger.warn("Client user ID mismatch for olm ping");
const [userClient] = await db
.select()
.from(clients)
.where(
and(
eq(clients.olmId, olm.olmId),
eq(clients.userId, olm.userId)
)
)
.limit(1);
if (!userClient) {
logger.warn("Client not found for olm ping");
return; return;
} }
client = userClient;
const sessionId = encodeHexLowerCase( const sessionId = encodeHexLowerCase(
sha256(new TextEncoder().encode(userToken)) sha256(new TextEncoder().encode(userToken))
); );
@@ -160,13 +167,6 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
} }
} }
if (!olm.clientId) {
logger.warn("Olm has no client ID!");
return;
}
try {
// Update the client's last ping timestamp
await db await db
.update(clients) .update(clients)
.set({ .set({
@@ -176,7 +176,12 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
}) })
.where(eq(clients.clientId, olm.clientId)); .where(eq(clients.clientId, olm.clientId));
await db.update(olms).set({ archived: false }).where(eq(olms.olmId, olm.olmId)); if (olm.archived) {
await db
.update(olms)
.set({ archived: false })
.where(eq(olms.olmId, olm.olmId));
}
} catch (error) { } catch (error) {
logger.error("Error handling ping message", { error }); logger.error("Error handling ping message", { error });
} }

View File

@@ -55,6 +55,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} }
if (client.blocked) {
logger.debug(`Client ${client.clientId} is blocked. Ignoring register.`);
return;
}
const [org] = await db const [org] = await db
.select() .select()
.from(orgs) .from(orgs)
@@ -112,18 +117,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if ( if (
(olmVersion && olm.version !== olmVersion) || (olmVersion && olm.version !== olmVersion) ||
(olmAgent && olm.agent !== olmAgent) (olmAgent && olm.agent !== olmAgent) ||
olm.archived
) { ) {
await db await db
.update(olms) .update(olms)
.set({ .set({
version: olmVersion, version: olmVersion,
agent: olmAgent agent: olmAgent,
archived: false
}) })
.where(eq(olms.olmId, olm.olmId)); .where(eq(olms.olmId, olm.olmId));
} }
if (client.pubKey !== publicKey) { if (client.pubKey !== publicKey || client.archived) {
logger.info( logger.info(
"Public key mismatch. Updating public key and clearing session info..." "Public key mismatch. Updating public key and clearing session info..."
); );
@@ -131,7 +138,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db await db
.update(clients) .update(clients)
.set({ .set({
pubKey: publicKey pubKey: publicKey,
archived: false,
}) })
.where(eq(clients.clientId, client.clientId)); .where(eq(clients.clientId, client.clientId));