Add limit in get_all and search for Graph (#1920)

This commit is contained in:
Dev Khant
2024-09-28 01:51:38 +05:30
committed by GitHub
parent aaf8e6e7ff
commit 68c7355f47
3 changed files with 18 additions and 11 deletions

View File

@@ -160,7 +160,7 @@ class MemoryGraph:
return returned_entities
def _search(self, query, filters):
def _search(self, query, filters, limit):
_tools = [SEARCH_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [SEARCH_STRUCT_TOOL]
@@ -219,24 +219,27 @@ class MemoryGraph:
MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
def search(self, query, filters):
def search(self, query, filters, limit):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve.
Returns:
dict: A dictionary containing:
@@ -244,7 +247,7 @@ class MemoryGraph:
- "entities": List of related graph data based on the query.
"""
search_output = self._search(query, filters)
search_output = self._search(query, filters, limit)
if not search_output:
return []
@@ -271,12 +274,13 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters):
def get_all(self, filters, limit):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve.
Returns:
list: A list of dictionaries, each containing:
- 'contexts': The base data store response for each memory.
@@ -287,8 +291,9 @@ class MemoryGraph:
query = """
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(query, params={"user_id": filters["user_id"]})
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
final_results = []
for result in results:

View File

@@ -286,7 +286,7 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = (
executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
)
all_memories = future_memories.result()
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
future_graph_entities = (
executor.submit(self.graph.search, query, filters)
executor.submit(self.graph.search, query, filters, limit)
if self.version == "v1.1" and self.enable_graph
else None
)

View File

@@ -1,8 +1,10 @@
import pytest
import os
from unittest.mock import Mock, patch
from mem0.memory.main import Memory
import pytest
from mem0.configs.base import MemoryConfig
from mem0.memory.main import Memory
@pytest.fixture(autouse=True)
@@ -119,7 +121,7 @@ def test_search(memory_instance, version, enable_graph):
memory_instance.embedding_model.embed.assert_called_once_with("test query")
if enable_graph:
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100)
else:
memory_instance.graph.search.assert_not_called()
@@ -217,6 +219,6 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
if enable_graph:
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
else:
memory_instance.graph.get_all.assert_not_called()