Add limit in get_all and search for Graph (#1920)

Dev Khant
2024-09-28 01:51:38 +05:30
committed by GitHub
parent aaf8e6e7ff
commit 68c7355f47
3 changed files with 18 additions and 11 deletions
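
For context, a minimal usage sketch of the change: with a graph-enabled v1.1 config, the limit passed to Memory.search and Memory.get_all is now forwarded to the graph store as well as the vector store. The config values below are placeholders for illustration, not part of this commit.

from mem0 import Memory

# Placeholder connection details; any Neo4j-backed graph_store config works here.
config = {
    "version": "v1.1",
    "graph_store": {
        "provider": "neo4j",
        "config": {
            "url": "bolt://localhost:7687",
            "username": "neo4j",
            "password": "password",
        },
    },
}

m = Memory.from_config(config)

# `limit` now caps both the vector-store results and the graph results.
hits = m.search("what does alice like?", user_id="alice", limit=10)
everything = m.get_all(user_id="alice", limit=100)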

View File

@@ -160,7 +160,7 @@ class MemoryGraph:
         return returned_entities

-    def _search(self, query, filters):
+    def _search(self, query, filters, limit):
         _tools = [SEARCH_TOOL]
         if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
             _tools = [SEARCH_STRUCT_TOOL]
@@ -219,24 +219,27 @@ class MemoryGraph:
             MATCH (m)-[r]->(n)
             RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
             ORDER BY similarity DESC
+            LIMIT $limit
             """
             params = {
                 "n_embedding": n_embedding,
                 "threshold": self.threshold,
                 "user_id": filters["user_id"],
+                "limit": limit,
             }
             ans = self.graph.query(cypher_query, params=params)
             result_relations.extend(ans)

         return result_relations

-    def search(self, query, filters):
+    def search(self, query, filters, limit):
         """
         Search for memories and related graph data.

         Args:
             query (str): Query to search for.
             filters (dict): A dictionary containing filters to be applied during the search.
+            limit (int): The maximum number of nodes and relationships to retrieve.

         Returns:
             dict: A dictionary containing:
@@ -244,7 +247,7 @@ class MemoryGraph:
- "entities": List of related graph data based on the query. - "entities": List of related graph data based on the query.
""" """
search_output = self._search(query, filters) search_output = self._search(query, filters, limit)
if not search_output: if not search_output:
return [] return []
@@ -271,12 +274,13 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]} params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params) self.graph.query(cypher, params=params)
def get_all(self, filters): def get_all(self, filters, limit):
""" """
Retrieves all nodes and relationships from the graph database based on optional filtering criteria. Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
Args: Args:
filters (dict): A dictionary containing filters to be applied during the retrieval. filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve.
Returns: Returns:
list: A list of dictionaries, each containing: list: A list of dictionaries, each containing:
- 'contexts': The base data store response for each memory. - 'contexts': The base data store response for each memory.
@@ -287,8 +291,9 @@ class MemoryGraph:
query = """ query = """
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id}) MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
""" """
results = self.graph.query(query, params={"user_id": filters["user_id"]}) results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
final_results = [] final_results = []
for result in results: for result in results:
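
In both queries the limit is passed as a Cypher parameter ($limit) rather than interpolated into the query string. Below is a standalone sketch of the same pattern using the official neo4j Python driver; mem0 itself goes through its own graph wrapper, and the connection details here are placeholders.

from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

# Same shape as the get_all query above: cap the rows with a parameterized LIMIT.
query = """
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""

with driver.session() as session:
    records = session.run(query, user_id="alice", limit=100)
    rows = [record.data() for record in records]

driver.close()
print(rows)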

View File

@@ -286,7 +286,7 @@ class Memory(MemoryBase):
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
             future_graph_entities = (
-                executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+                executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
             )

             all_memories = future_memories.result()
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._search_vector_store, query, filters, limit)
             future_graph_entities = (
-                executor.submit(self.graph.search, query, filters)
+                executor.submit(self.graph.search, query, filters, limit)
                 if self.version == "v1.1" and self.enable_graph
                 else None
             )
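
Both call sites follow the same fan-out pattern: the vector-store lookup and the graph lookup run in parallel, and both now receive the same limit. A self-contained sketch of that pattern follows; vector_search, graph_search, and the result keys are stand-ins for illustration, not mem0 APIs.

import concurrent.futures

def vector_search(query, filters, limit):
    # Stand-in for the vector-store lookup.
    return [f"vector:{query}:{limit}"]

def graph_search(query, filters, limit):
    # Stand-in for the graph-store lookup.
    return [f"graph:{query}:{limit}"]

def combined_search(query, filters, limit, graph_enabled=True):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        memories = executor.submit(vector_search, query, filters, limit)
        entities = executor.submit(graph_search, query, filters, limit) if graph_enabled else None
        results = {"results": memories.result()}
        if entities is not None:
            results["relations"] = entities.result()
        return results

print(combined_search("favourite food", {"user_id": "alice"}, limit=10))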

View File

@@ -1,8 +1,10 @@
-import pytest
 import os
 from unittest.mock import Mock, patch
-from mem0.memory.main import Memory
+
+import pytest
+
 from mem0.configs.base import MemoryConfig
+from mem0.memory.main import Memory

 @pytest.fixture(autouse=True)
@@ -119,7 +121,7 @@ def test_search(memory_instance, version, enable_graph):
     memory_instance.embedding_model.embed.assert_called_once_with("test query")
     if enable_graph:
-        memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
+        memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100)
     else:
         memory_instance.graph.search.assert_not_called()
@@ -217,6 +219,6 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
     memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
     if enable_graph:
-        memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
+        memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
     else:
         memory_instance.graph.get_all.assert_not_called()
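
The updated assertions check that the graph store now receives the limit (100, the default used by these tests) as an extra positional argument. A bare unittest.mock sketch of the same assertion style, independent of the mem0 fixtures:

from unittest.mock import Mock

graph = Mock()

# Simulate what Memory.search / Memory.get_all now pass through.
graph.search("test query", {"user_id": "test_user"}, 100)
graph.get_all({"user_id": "test_user"}, 100)

graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100)
graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)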