Added threshold to search (#2899)

Prateek Chhikara
2025-06-03 03:58:21 -06:00
committed by GitHub
parent 849452cc93
commit be37fca1bb
2 changed files with 18 additions and 12 deletions


@@ -607,6 +607,7 @@ class Memory(MemoryBase):
run_id: Optional[str] = None,
limit: int = 100,
filters: Optional[Dict[str, Any]] = None,
+ threshold: Optional[float] = None,
):
"""
Searches for memories based on a query
@@ -617,6 +618,7 @@ class Memory(MemoryBase):
run_id (str, optional): ID of the run to search for. Defaults to None.
limit (int, optional): Limit the number of results. Defaults to 100.
filters (dict, optional): Filters to apply to the search. Defaults to None..
+ threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None.
Returns:
dict: A dictionary containing the search results, typically under a "results" key,
@@ -633,11 +635,11 @@ class Memory(MemoryBase):
capture_event(
"mem0.search",
self,
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync"},
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync", "threshold": threshold},
)
with concurrent.futures.ThreadPoolExecutor() as executor:
- future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit)
+ future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit, threshold)
future_graph_entities = (
executor.submit(self.graph.search, query, effective_filters, limit) if self.enable_graph else None
)
@@ -664,7 +666,7 @@ class Memory(MemoryBase):
else:
return {"results": original_memories}
- def _search_vector_store(self, query, filters, limit):
+ def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None):
embeddings = self.embedding_model.embed(query, "search")
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
@@ -697,6 +699,7 @@ class Memory(MemoryBase):
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
+ if threshold is None or mem.score >= threshold:
original_memories.append(memory_item_dict)
return original_memories
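
Net effect on the sync path: _search_vector_store now drops any vector-store hit whose similarity score falls below threshold, while threshold=None (the default) preserves the previous behaviour. A minimal usage sketch of the new parameter, assuming a default-configured Memory instance and a hypothetical user id:

from mem0 import Memory

m = Memory()  # assumes a default vector store / embedder configuration

# Keep only hits whose similarity score is >= 0.7; with threshold=None
# (the default) the result set matches the previous behaviour.
results = m.search(
    "What are my dietary preferences?",
    user_id="alice",  # hypothetical user id
    limit=10,
    threshold=0.7,
)

for hit in results["results"]:
    print(hit["memory"], hit.get("score"))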
@@ -1442,6 +1445,7 @@ class AsyncMemory(MemoryBase):
run_id: Optional[str] = None,
limit: int = 100,
filters: Optional[Dict[str, Any]] = None,
+ threshold: Optional[float] = None,
):
"""
Searches for memories based on a query
@@ -1451,7 +1455,8 @@ class AsyncMemory(MemoryBase):
agent_id (str, optional): ID of the agent to search for. Defaults to None.
run_id (str, optional): ID of the run to search for. Defaults to None.
limit (int, optional): Limit the number of results. Defaults to 100.
- filters (dict, optional): Filters to apply to the search. Defaults to None..
+ filters (dict, optional): Filters to apply to the search. Defaults to None.
+ threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None.
Returns:
dict: A dictionary containing the search results, typically under a "results" key,
@@ -1469,10 +1474,10 @@ class AsyncMemory(MemoryBase):
capture_event(
"mem0.search",
self,
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async"},
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async", "threshold": threshold},
)
- vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))
+ vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit, threshold))
graph_task = None
if self.enable_graph:
@@ -1502,7 +1507,7 @@ class AsyncMemory(MemoryBase):
else:
return {"results": original_memories}
- async def _search_vector_store(self, query, filters, limit):
+ async def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None):
embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
memories = await asyncio.to_thread(
self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters
@@ -1537,6 +1542,7 @@ class AsyncMemory(MemoryBase):
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
+ if threshold is None or mem.score >= threshold:
original_memories.append(memory_item_dict)
return original_memories
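
The async path mirrors the sync changes one-to-one. A corresponding sketch for AsyncMemory (assuming it is exported from the mem0 package as in current releases, and again using a hypothetical user id):

import asyncio

from mem0 import AsyncMemory

async def main():
    m = AsyncMemory()  # assumes the same default configuration as the sync example
    # Same semantics: hits scoring below the threshold are filtered out.
    results = await m.search(
        "What are my dietary preferences?",
        user_id="alice",  # hypothetical user id
        threshold=0.7,
    )
    for hit in results["results"]:
        print(hit["memory"], hit.get("score"))

asyncio.run(main())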


@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "mem0ai"
version = "0.1.104"
version = "0.1.105"
description = "Long-term memory for AI Agents"
authors = [
{ name = "Mem0", email = "founders@mem0.ai" }