Added threshold to search (#2899)
This commit is contained in:
@@ -607,6 +607,7 @@ class Memory(MemoryBase):
|
|||||||
run_id: Optional[str] = None,
|
run_id: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
|
threshold: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Searches for memories based on a query
|
Searches for memories based on a query
|
||||||
@@ -617,6 +618,7 @@ class Memory(MemoryBase):
|
|||||||
run_id (str, optional): ID of the run to search for. Defaults to None.
|
run_id (str, optional): ID of the run to search for. Defaults to None.
|
||||||
limit (int, optional): Limit the number of results. Defaults to 100.
|
limit (int, optional): Limit the number of results. Defaults to 100.
|
||||||
filters (dict, optional): Filters to apply to the search. Defaults to None..
|
filters (dict, optional): Filters to apply to the search. Defaults to None..
|
||||||
|
threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the search results, typically under a "results" key,
|
dict: A dictionary containing the search results, typically under a "results" key,
|
||||||
@@ -633,11 +635,11 @@ class Memory(MemoryBase):
|
|||||||
capture_event(
|
capture_event(
|
||||||
"mem0.search",
|
"mem0.search",
|
||||||
self,
|
self,
|
||||||
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync"},
|
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync", "threshold": threshold},
|
||||||
)
|
)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit)
|
future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit, threshold)
|
||||||
future_graph_entities = (
|
future_graph_entities = (
|
||||||
executor.submit(self.graph.search, query, effective_filters, limit) if self.enable_graph else None
|
executor.submit(self.graph.search, query, effective_filters, limit) if self.enable_graph else None
|
||||||
)
|
)
|
||||||
@@ -664,7 +666,7 @@ class Memory(MemoryBase):
|
|||||||
else:
|
else:
|
||||||
return {"results": original_memories}
|
return {"results": original_memories}
|
||||||
|
|
||||||
def _search_vector_store(self, query, filters, limit):
|
def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None):
|
||||||
embeddings = self.embedding_model.embed(query, "search")
|
embeddings = self.embedding_model.embed(query, "search")
|
||||||
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
|
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
|
||||||
|
|
||||||
@@ -696,8 +698,9 @@ class Memory(MemoryBase):
|
|||||||
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
|
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
|
||||||
if additional_metadata:
|
if additional_metadata:
|
||||||
memory_item_dict["metadata"] = additional_metadata
|
memory_item_dict["metadata"] = additional_metadata
|
||||||
|
|
||||||
original_memories.append(memory_item_dict)
|
if threshold is None or mem.score >= threshold:
|
||||||
|
original_memories.append(memory_item_dict)
|
||||||
|
|
||||||
return original_memories
|
return original_memories
|
||||||
|
|
||||||
@@ -1442,6 +1445,7 @@ class AsyncMemory(MemoryBase):
|
|||||||
run_id: Optional[str] = None,
|
run_id: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
|
threshold: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Searches for memories based on a query
|
Searches for memories based on a query
|
||||||
@@ -1451,7 +1455,8 @@ class AsyncMemory(MemoryBase):
|
|||||||
agent_id (str, optional): ID of the agent to search for. Defaults to None.
|
agent_id (str, optional): ID of the agent to search for. Defaults to None.
|
||||||
run_id (str, optional): ID of the run to search for. Defaults to None.
|
run_id (str, optional): ID of the run to search for. Defaults to None.
|
||||||
limit (int, optional): Limit the number of results. Defaults to 100.
|
limit (int, optional): Limit the number of results. Defaults to 100.
|
||||||
filters (dict, optional): Filters to apply to the search. Defaults to None..
|
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
threshold (float, optional): Minimum score for a memory to be included in the results. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the search results, typically under a "results" key,
|
dict: A dictionary containing the search results, typically under a "results" key,
|
||||||
@@ -1469,10 +1474,10 @@ class AsyncMemory(MemoryBase):
|
|||||||
capture_event(
|
capture_event(
|
||||||
"mem0.search",
|
"mem0.search",
|
||||||
self,
|
self,
|
||||||
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async"},
|
{"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async", "threshold": threshold},
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))
|
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit, threshold))
|
||||||
|
|
||||||
graph_task = None
|
graph_task = None
|
||||||
if self.enable_graph:
|
if self.enable_graph:
|
||||||
@@ -1502,7 +1507,7 @@ class AsyncMemory(MemoryBase):
|
|||||||
else:
|
else:
|
||||||
return {"results": original_memories}
|
return {"results": original_memories}
|
||||||
|
|
||||||
async def _search_vector_store(self, query, filters, limit):
|
async def _search_vector_store(self, query, filters, limit, threshold: Optional[float] = None):
|
||||||
embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
|
embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
|
||||||
memories = await asyncio.to_thread(
|
memories = await asyncio.to_thread(
|
||||||
self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters
|
self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters
|
||||||
@@ -1536,8 +1541,9 @@ class AsyncMemory(MemoryBase):
|
|||||||
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
|
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
|
||||||
if additional_metadata:
|
if additional_metadata:
|
||||||
memory_item_dict["metadata"] = additional_metadata
|
memory_item_dict["metadata"] = additional_metadata
|
||||||
|
|
||||||
original_memories.append(memory_item_dict)
|
if threshold is None or mem.score >= threshold:
|
||||||
|
original_memories.append(memory_item_dict)
|
||||||
|
|
||||||
return original_memories
|
return original_memories
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "mem0ai"
|
name = "mem0ai"
|
||||||
version = "0.1.104"
|
version = "0.1.105"
|
||||||
description = "Long-term memory for AI Agents"
|
description = "Long-term memory for AI Agents"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Mem0", email = "founders@mem0.ai" }
|
{ name = "Mem0", email = "founders@mem0.ai" }
|
||||||
|
|||||||
Reference in New Issue
Block a user