Add limit in get_all and search for Graph (#1920)

This commit is contained in:
Dev Khant
2024-09-28 01:51:38 +05:30
committed by GitHub
parent aaf8e6e7ff
commit 68c7355f47
3 changed files with 18 additions and 11 deletions

View File

@@ -160,7 +160,7 @@ class MemoryGraph:
return returned_entities
def _search(self, query, filters):
def _search(self, query, filters, limit):
_tools = [SEARCH_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [SEARCH_STRUCT_TOOL]
@@ -219,24 +219,27 @@ class MemoryGraph:
MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
def search(self, query, filters):
def search(self, query, filters, limit):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve.
Returns:
dict: A dictionary containing:
@@ -244,7 +247,7 @@ class MemoryGraph:
- "entities": List of related graph data based on the query.
"""
search_output = self._search(query, filters)
search_output = self._search(query, filters, limit)
if not search_output:
return []
@@ -271,12 +274,13 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters):
def get_all(self, filters, limit):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve.
Returns:
list: A list of dictionaries, each containing:
- 'contexts': The base data store response for each memory.
@@ -287,8 +291,9 @@ class MemoryGraph:
query = """
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(query, params={"user_id": filters["user_id"]})
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
final_results = []
for result in results:

View File

@@ -286,7 +286,7 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = (
executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
)
all_memories = future_memories.result()
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
future_graph_entities = (
executor.submit(self.graph.search, query, filters)
executor.submit(self.graph.search, query, filters, limit)
if self.version == "v1.1" and self.enable_graph
else None
)