Add limit in get_all and search for Graph (#1920)
This commit is contained in:
@@ -160,7 +160,7 @@ class MemoryGraph:
|
|||||||
|
|
||||||
return returned_entities
|
return returned_entities
|
||||||
|
|
||||||
def _search(self, query, filters):
|
def _search(self, query, filters, limit):
|
||||||
_tools = [SEARCH_TOOL]
|
_tools = [SEARCH_TOOL]
|
||||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||||
_tools = [SEARCH_STRUCT_TOOL]
|
_tools = [SEARCH_STRUCT_TOOL]
|
||||||
@@ -219,24 +219,27 @@ class MemoryGraph:
|
|||||||
MATCH (m)-[r]->(n)
|
MATCH (m)-[r]->(n)
|
||||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
||||||
ORDER BY similarity DESC
|
ORDER BY similarity DESC
|
||||||
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
params = {
|
params = {
|
||||||
"n_embedding": n_embedding,
|
"n_embedding": n_embedding,
|
||||||
"threshold": self.threshold,
|
"threshold": self.threshold,
|
||||||
"user_id": filters["user_id"],
|
"user_id": filters["user_id"],
|
||||||
|
"limit": limit,
|
||||||
}
|
}
|
||||||
ans = self.graph.query(cypher_query, params=params)
|
ans = self.graph.query(cypher_query, params=params)
|
||||||
result_relations.extend(ans)
|
result_relations.extend(ans)
|
||||||
|
|
||||||
return result_relations
|
return result_relations
|
||||||
|
|
||||||
def search(self, query, filters):
|
def search(self, query, filters, limit):
|
||||||
"""
|
"""
|
||||||
Search for memories and related graph data.
|
Search for memories and related graph data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): Query to search for.
|
query (str): Query to search for.
|
||||||
filters (dict): A dictionary containing filters to be applied during the search.
|
filters (dict): A dictionary containing filters to be applied during the search.
|
||||||
|
limit (int): The maximum number of nodes and relationships to retrieve.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing:
|
dict: A dictionary containing:
|
||||||
@@ -244,7 +247,7 @@ class MemoryGraph:
|
|||||||
- "entities": List of related graph data based on the query.
|
- "entities": List of related graph data based on the query.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_output = self._search(query, filters)
|
search_output = self._search(query, filters, limit)
|
||||||
|
|
||||||
if not search_output:
|
if not search_output:
|
||||||
return []
|
return []
|
||||||
@@ -271,12 +274,13 @@ class MemoryGraph:
|
|||||||
params = {"user_id": filters["user_id"]}
|
params = {"user_id": filters["user_id"]}
|
||||||
self.graph.query(cypher, params=params)
|
self.graph.query(cypher, params=params)
|
||||||
|
|
||||||
def get_all(self, filters):
|
def get_all(self, filters, limit):
|
||||||
"""
|
"""
|
||||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
||||||
|
limit (int): The maximum number of nodes and relationships to retrieve.
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of dictionaries, each containing:
|
list: A list of dictionaries, each containing:
|
||||||
- 'contexts': The base data store response for each memory.
|
- 'contexts': The base data store response for each memory.
|
||||||
@@ -287,8 +291,9 @@ class MemoryGraph:
|
|||||||
query = """
|
query = """
|
||||||
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
|
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
|
||||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||||
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
results = self.graph.query(query, params={"user_id": filters["user_id"]})
|
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
|
||||||
|
|
||||||
final_results = []
|
final_results = []
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|||||||
@@ -286,7 +286,7 @@ class Memory(MemoryBase):
|
|||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
|
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
|
||||||
future_graph_entities = (
|
future_graph_entities = (
|
||||||
executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
|
executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
|
||||||
)
|
)
|
||||||
|
|
||||||
all_memories = future_memories.result()
|
all_memories = future_memories.result()
|
||||||
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
|
|||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
|
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
|
||||||
future_graph_entities = (
|
future_graph_entities = (
|
||||||
executor.submit(self.graph.search, query, filters)
|
executor.submit(self.graph.search, query, filters, limit)
|
||||||
if self.version == "v1.1" and self.enable_graph
|
if self.version == "v1.1" and self.enable_graph
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import pytest
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from mem0.memory.main import Memory
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mem0.configs.base import MemoryConfig
|
from mem0.configs.base import MemoryConfig
|
||||||
|
from mem0.memory.main import Memory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -119,7 +121,7 @@ def test_search(memory_instance, version, enable_graph):
|
|||||||
memory_instance.embedding_model.embed.assert_called_once_with("test query")
|
memory_instance.embedding_model.embed.assert_called_once_with("test query")
|
||||||
|
|
||||||
if enable_graph:
|
if enable_graph:
|
||||||
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
|
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100)
|
||||||
else:
|
else:
|
||||||
memory_instance.graph.search.assert_not_called()
|
memory_instance.graph.search.assert_not_called()
|
||||||
|
|
||||||
@@ -217,6 +219,6 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
|
|||||||
memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
|
memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
|
||||||
|
|
||||||
if enable_graph:
|
if enable_graph:
|
||||||
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
|
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
|
||||||
else:
|
else:
|
||||||
memory_instance.graph.get_all.assert_not_called()
|
memory_instance.graph.get_all.assert_not_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user