Added threshold to search (#2899)
This commit is contained in:
@@ -607,6 +607,7 @@ class Memory(MemoryBase):
|
||||
run_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
threshold: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Searches for memories based on a query
|
||||
@@ -617,6 +618,7 @@ class Memory(MemoryBase):
|
||||
run_id (str, optional): ID of the run to search for. Defaults to None.
|
||||
limit (int, optional): Limit the number of results. Defaults to 100.
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None..
|
||||
threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the search results, typically under a "results" key,
|
||||
@@ -633,11 +635,11 @@ class Memory(MemoryBase):
|
||||
capture_event(
|
||||
"mem0.search",
|
||||
self,
|
||||
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync"},
|
||||
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync", "threshold": threshold},
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit)
|
||||
future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit, threshold)
|
||||
future_graph_entities = (
|
||||
executor.submit(self.graph.search, query, effective_filters, limit) if self.enable_graph else None
|
||||
)
|
||||
@@ -664,7 +666,7 @@ class Memory(MemoryBase):
|
||||
else:
|
||||
return {"results": original_memories}
|
||||
|
||||
def _search_vector_store(self, query, filters, limit):
|
||||
def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None):
|
||||
embeddings = self.embedding_model.embed(query, "search")
|
||||
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
|
||||
|
||||
@@ -697,6 +699,7 @@ class Memory(MemoryBase):
|
||||
if additional_metadata:
|
||||
memory_item_dict["metadata"] = additional_metadata
|
||||
|
||||
if threshold is None or mem.score >= threshold:
|
||||
original_memories.append(memory_item_dict)
|
||||
|
||||
return original_memories
|
||||
@@ -1442,6 +1445,7 @@ class AsyncMemory(MemoryBase):
|
||||
run_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
threshold: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Searches for memories based on a query
|
||||
@@ -1451,7 +1455,8 @@ class AsyncMemory(MemoryBase):
|
||||
agent_id (str, optional): ID of the agent to search for. Defaults to None.
|
||||
run_id (str, optional): ID of the run to search for. Defaults to None.
|
||||
limit (int, optional): Limit the number of results. Defaults to 100.
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None..
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||
threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the search results, typically under a "results" key,
|
||||
@@ -1469,10 +1474,10 @@ class AsyncMemory(MemoryBase):
|
||||
capture_event(
|
||||
"mem0.search",
|
||||
self,
|
||||
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async"},
|
||||
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async", "threshold": threshold},
|
||||
)
|
||||
|
||||
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))
|
||||
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit, threshold))
|
||||
|
||||
graph_task = None
|
||||
if self.enable_graph:
|
||||
@@ -1502,7 +1507,7 @@ class AsyncMemory(MemoryBase):
|
||||
else:
|
||||
return {"results": original_memories}
|
||||
|
||||
async def _search_vector_store(self, query, filters, limit):
|
||||
async def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None):
|
||||
embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
|
||||
memories = await asyncio.to_thread(
|
||||
self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters
|
||||
@@ -1537,6 +1542,7 @@ class AsyncMemory(MemoryBase):
|
||||
if additional_metadata:
|
||||
memory_item_dict["metadata"] = additional_metadata
|
||||
|
||||
if threshold is None or mem.score >= threshold:
|
||||
original_memories.append(memory_item_dict)
|
||||
|
||||
return original_memories
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "mem0ai"
|
||||
version = "0.1.104"
|
||||
version = "0.1.105"
|
||||
description = "Long-term memory for AI Agents"
|
||||
authors = [
|
||||
{ name = "Mem0", email = "founders@mem0.ai" }
|
||||
|
||||
Reference in New Issue
Block a user