From be37fca1bba588b040151bf1a1766ac1f375a57e Mon Sep 17 00:00:00 2001 From: Prateek Chhikara <46902268+prateekchhikara@users.noreply.github.com> Date: Tue, 3 Jun 2025 03:58:21 -0600 Subject: [PATCH] Added threshold to search (#2899) --- mem0/memory/main.py | 28 +++++++++++++++++----------- pyproject.toml | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 9aee0acb..f1b26f4d 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -607,6 +607,7 @@ class Memory(MemoryBase): run_id: Optional[str] = None, limit: int = 100, filters: Optional[Dict[str, Any]] = None, + threshold: Optional[float] = None, ): """ Searches for memories based on a query @@ -617,6 +618,7 @@ class Memory(MemoryBase): run_id (str, optional): ID of the run to search for. Defaults to None. limit (int, optional): Limit the number of results. Defaults to 100. filters (dict, optional): Filters to apply to the search. Defaults to None.. + threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None. Returns: dict: A dictionary containing the search results, typically under a "results" key, @@ -633,11 +635,11 @@ class Memory(MemoryBase): capture_event( "mem0.search", self, - {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync"}, + {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync", "threshold": threshold}, ) with concurrent.futures.ThreadPoolExecutor() as executor: - future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit) + future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit, threshold) future_graph_entities = ( executor.submit(self.graph.search, query, effective_filters, limit) if self.enable_graph else None ) @@ -664,7 +666,7 @@ class Memory(MemoryBase): else: return {"results": original_memories} - def _search_vector_store(self, query, filters, limit): + def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None): embeddings = self.embedding_model.embed(query, "search") memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters) @@ -696,8 +698,9 @@ class Memory(MemoryBase): additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys} if additional_metadata: memory_item_dict["metadata"] = additional_metadata - - original_memories.append(memory_item_dict) + + if threshold is None or mem.score >= threshold: + original_memories.append(memory_item_dict) return original_memories @@ -1442,6 +1445,7 @@ class AsyncMemory(MemoryBase): run_id: Optional[str] = None, limit: int = 100, filters: Optional[Dict[str, Any]] = None, + threshold: Optional[float] = None, ): """ Searches for memories based on a query @@ -1451,7 +1455,8 @@ class AsyncMemory(MemoryBase): agent_id (str, optional): ID of the agent to search for. Defaults to None. run_id (str, optional): ID of the run to search for. Defaults to None. limit (int, optional): Limit the number of results. Defaults to 100. - filters (dict, optional): Filters to apply to the search. Defaults to None.. + filters (dict, optional): Filters to apply to the search. Defaults to None. + threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None. Returns: dict: A dictionary containing the search results, typically under a "results" key, @@ -1469,10 +1474,10 @@ class AsyncMemory(MemoryBase): capture_event( "mem0.search", self, - {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async"}, + {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async", "threshold": threshold}, ) - vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit)) + vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit, threshold)) graph_task = None if self.enable_graph: @@ -1502,7 +1507,7 @@ class AsyncMemory(MemoryBase): else: return {"results": original_memories} - async def _search_vector_store(self, query, filters, limit): + async def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None): embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search") memories = await asyncio.to_thread( self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters @@ -1536,8 +1541,9 @@ class AsyncMemory(MemoryBase): additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys} if additional_metadata: memory_item_dict["metadata"] = additional_metadata - - original_memories.append(memory_item_dict) + + if threshold is None or mem.score >= threshold: + original_memories.append(memory_item_dict) return original_memories diff --git a/pyproject.toml b/pyproject.toml index 62251f0f..b52f40d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mem0ai" -version = "0.1.104" +version = "0.1.105" description = "Long-term memory for AI Agents" authors = [ { name = "Mem0", email = "founders@mem0.ai" }