From 3acd9e20da4b26d7899d4855b7b7efed2d7af3f5 Mon Sep 17 00:00:00 2001 From: Saket Aryan Date: Tue, 18 Mar 2025 04:00:40 +0530 Subject: [PATCH] Fix Redis Search (#2392) --- mem0-ts/package.json | 2 +- mem0-ts/src/oss/src/vector_stores/redis.ts | 68 +++++++++++++++------- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/mem0-ts/package.json b/mem0-ts/package.json index 9c04f2ee..2b0e6fc4 100644 --- a/mem0-ts/package.json +++ b/mem0-ts/package.json @@ -1,6 +1,6 @@ { "name": "mem0ai", - "version": "2.1.4", + "version": "2.1.5", "description": "The Memory Layer For Your AI Apps", "main": "./dist/index.js", "module": "./dist/index.mjs", diff --git a/mem0-ts/src/oss/src/vector_stores/redis.ts b/mem0-ts/src/oss/src/vector_stores/redis.ts index 7b0188f6..796199dc 100644 --- a/mem0-ts/src/oss/src/vector_stores/redis.ts +++ b/mem0-ts/src/oss/src/vector_stores/redis.ts @@ -62,7 +62,7 @@ interface RedisDocument { run_id?: string; user_id?: string; metadata?: string; - vector_score?: number; + __vector_score?: number; }; } @@ -108,6 +108,30 @@ const EXCLUDED_KEYS = new Set([ "updated_at", ]); +// Utility function to convert object keys to snake_case +function toSnakeCase(obj: Record): Record { + if (typeof obj !== "object" || obj === null) return obj; + + return Object.fromEntries( + Object.entries(obj).map(([key, value]) => [ + key.replace(/[A-Z]/g, (letter) => `_${letter.toLowerCase()}`), + value, + ]), + ); +} + +// Utility function to convert object keys to camelCase +function toCamelCase(obj: Record): Record { + if (typeof obj !== "object" || obj === null) return obj; + + return Object.fromEntries( + Object.entries(obj).map(([key, value]) => [ + key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()), + value, + ]), + ); +} + export class RedisDB implements VectorStore { private client: RedisClientType< RedisDefaultModules & RedisModules & RedisFunctions & RedisScripts @@ -272,7 +296,7 @@ export class RedisDB implements VectorStore { payloads: Record[], ): Promise { const data = vectors.map((vector, idx) => { - const payload = payloads[idx]; + const payload = toSnakeCase(payloads[idx]); const id = ids[idx]; // Create entry with required fields @@ -322,8 +346,9 @@ export class RedisDB implements VectorStore { limit: number = 5, filters?: SearchFilters, ): Promise { - const filterExpr = filters - ? Object.entries(filters) + const snakeFilters = filters ? toSnakeCase(filters) : undefined; + const filterExpr = snakeFilters + ? Object.entries(snakeFilters) .filter(([_, value]) => value !== null) .map(([key, value]) => `@${key}:{${value}}`) .join(" ") @@ -344,8 +369,9 @@ export class RedisDB implements VectorStore { "memory", "metadata", "created_at", + "__vector_score", ], - SORTBY: "vector_score", + SORTBY: "__vector_score", DIALECT: 2, LIMIT: { from: 0, @@ -356,12 +382,12 @@ export class RedisDB implements VectorStore { try { const results = (await this.client.ft.search( this.indexName, - `${filterExpr} =>[KNN ${limit} @embedding $vec AS vector_score]`, + `${filterExpr} =>[KNN ${limit} @embedding $vec AS __vector_score]`, searchOptions, )) as unknown as RedisSearchResult; return results.documents.map((doc) => { - const payload = { + const resultPayload = { hash: doc.value.hash, data: doc.value.memory, created_at: new Date(parseInt(doc.value.created_at)).toISOString(), @@ -376,8 +402,8 @@ export class RedisDB implements VectorStore { return { id: doc.value.memory_id, - payload, - score: doc.value.vector_score, + payload: toCamelCase(resultPayload), + score: Number(doc.value.__vector_score) ?? 0, }; }); } catch (error) { @@ -493,26 +519,27 @@ export class RedisDB implements VectorStore { vector: number[], payload: Record, ): Promise { + const snakePayload = toSnakeCase(payload); const entry: Record = { memory_id: vectorId, - hash: payload.hash, - memory: payload.data, - created_at: new Date(payload.created_at).getTime(), - updated_at: new Date(payload.updated_at).getTime(), + hash: snakePayload.hash, + memory: snakePayload.data, + created_at: new Date(snakePayload.created_at).getTime(), + updated_at: new Date(snakePayload.updated_at).getTime(), embedding: Buffer.from(new Float32Array(vector).buffer), }; // Add optional fields ["agent_id", "run_id", "user_id"].forEach((field) => { - if (field in payload) { - entry[field] = payload[field]; + if (field in snakePayload) { + entry[field] = snakePayload[field]; } }); // Add metadata excluding specific keys entry.metadata = JSON.stringify( Object.fromEntries( - Object.entries(payload).filter(([key]) => !EXCLUDED_KEYS.has(key)), + Object.entries(snakePayload).filter(([key]) => !EXCLUDED_KEYS.has(key)), ), ); @@ -557,8 +584,9 @@ export class RedisDB implements VectorStore { filters?: SearchFilters, limit: number = 100, ): Promise<[VectorStoreResult[], number]> { - const filterExpr = filters - ? Object.entries(filters) + const snakeFilters = filters ? toSnakeCase(filters) : undefined; + const filterExpr = snakeFilters + ? Object.entries(snakeFilters) .filter(([_, value]) => value !== null) .map(([key, value]) => `@${key}:{${value}}`) .join(" ") @@ -581,7 +609,7 @@ export class RedisDB implements VectorStore { const items = results.documents.map((doc) => ({ id: doc.value.memory_id, - payload: { + payload: toCamelCase({ hash: doc.value.hash, data: doc.value.memory, created_at: new Date(parseInt(doc.value.created_at)).toISOString(), @@ -592,7 +620,7 @@ export class RedisDB implements VectorStore { ...(doc.value.run_id && { run_id: doc.value.run_id }), ...(doc.value.user_id && { user_id: doc.value.user_id }), ...JSON.parse(doc.value.metadata || "{}"), - }, + }), })); return [items, results.total];