diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index c14db70b..987b923e 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -160,7 +160,7 @@ class MemoryGraph: return returned_entities - def _search(self, query, filters): + def _search(self, query, filters, limit=100): _tools = [SEARCH_TOOL] if self.llm_provider in ["azure_openai_structured", "openai_structured"]: _tools = [SEARCH_STRUCT_TOOL] @@ -219,24 +219,27 @@ class MemoryGraph: MATCH (m)-[r]->(n) RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity ORDER BY similarity DESC + LIMIT $limit """ params = { "n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"], + "limit": limit, } ans = self.graph.query(cypher_query, params=params) result_relations.extend(ans) return result_relations - def search(self, query, filters): + def search(self, query, filters, limit=100): """ Search for memories and related graph data. Args: query (str): Query to search for. filters (dict): A dictionary containing filters to be applied during the search. + limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100. Returns: dict: A dictionary containing: @@ -244,7 +247,7 @@ class MemoryGraph: - "entities": List of related graph data based on the query. """ - search_output = self._search(query, filters) + search_output = self._search(query, filters, limit) if not search_output: return [] @@ -271,12 +274,13 @@ class MemoryGraph: params = {"user_id": filters["user_id"]} self.graph.query(cypher, params=params) - def get_all(self, filters): + def get_all(self, filters, limit=100): """ Retrieves all nodes and relationships from the graph database based on optional filtering criteria. Args: filters (dict): A dictionary containing filters to be applied during the retrieval. + limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100. 
Returns: list: A list of dictionaries, each containing: - 'contexts': The base data store response for each memory. @@ -287,8 +291,9 @@ class MemoryGraph: query = """ MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id}) RETURN n.name AS source, type(r) AS relationship, m.name AS target + LIMIT $limit """ - results = self.graph.query(query, params={"user_id": filters["user_id"]}) + results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit}) final_results = [] for result in results: diff --git a/mem0/memory/main.py b/mem0/memory/main.py index d97c344f..7d82c632 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -286,7 +286,7 @@ class Memory(MemoryBase): with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) future_graph_entities = ( - executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None + executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None ) all_memories = future_memories.result() @@ -374,7 +374,7 @@ class Memory(MemoryBase): with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._search_vector_store, query, filters, limit) future_graph_entities = ( - executor.submit(self.graph.search, query, filters) + executor.submit(self.graph.search, query, filters, limit) if self.version == "v1.1" and self.enable_graph else None ) diff --git a/tests/test_main.py b/tests/test_main.py index 8ed22245..fd667b4f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,10 @@ -import pytest import os from unittest.mock import Mock, patch -from mem0.memory.main import Memory + +import pytest + from mem0.configs.base import MemoryConfig +from mem0.memory.main import Memory @pytest.fixture(autouse=True) @@ -119,7 +121,7 @@ def test_search(memory_instance, version, enable_graph): 
memory_instance.embedding_model.embed.assert_called_once_with("test query") if enable_graph: - memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}) + memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100) else: memory_instance.graph.search.assert_not_called() @@ -217,6 +219,6 @@ def test_get_all(memory_instance, version, enable_graph, expected_result): memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100) if enable_graph: - memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}) + memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100) else: memory_instance.graph.get_all.assert_not_called()