Fix: Add Google Genai library support (#2941)
This commit is contained in:
@@ -80,8 +80,8 @@ class MemoryGraph:
|
||||
|
||||
# TODO: Batch queries with APOC plugin
|
||||
# TODO: Add more filter support
|
||||
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
||||
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
||||
deleted_entities = self._delete_entities(to_be_deleted, filters)
|
||||
added_entities = self._add_entities(to_be_added, filters, entity_type_map)
|
||||
|
||||
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
||||
|
||||
@@ -122,32 +122,35 @@ class MemoryGraph:
|
||||
return search_results
|
||||
|
||||
def delete_all(self, filters):
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
if filters.get("agent_id"):
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id, agent_id: $agent_id}})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
|
||||
else:
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
def get_all(self, filters, limit=100):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
|
||||
Args:
|
||||
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
||||
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
||||
Returns:
|
||||
list: A list of dictionaries, each containing:
|
||||
- 'contexts': The base data store response for each memory.
|
||||
- 'entities': A list of strings representing the nodes and relationships
|
||||
"""
|
||||
# return all nodes and relationships
|
||||
def get_all(self, filters, limit=100):
|
||||
agent_filter = ""
|
||||
params = {"user_id": filters["user_id"], "limit": limit}
|
||||
if filters.get("agent_id"):
|
||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||
params["agent_id"] = filters["agent_id"]
|
||||
|
||||
query = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
||||
WHERE 1=1 {agent_filter}
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
|
||||
results = self.graph.query(query, params=params)
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
@@ -163,6 +166,7 @@ class MemoryGraph:
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
def _retrieve_nodes_from_data(self, data, filters):
|
||||
"""Extracts all the entities mentioned in the query."""
|
||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||
@@ -197,23 +201,27 @@ class MemoryGraph:
|
||||
return entity_type_map
|
||||
|
||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||
"""Eshtablish relations among the extracted nodes."""
|
||||
"""Establish relations among the extracted nodes."""
|
||||
|
||||
# Compose user identification string for prompt
|
||||
user_identity = f"user_id: {filters['user_id']}"
|
||||
if filters.get("agent_id"):
|
||||
user_identity += f", agent_id: {filters['agent_id']}"
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
||||
# Add the custom prompt line if configured
|
||||
system_content = system_content.replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
|
||||
},
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
|
||||
]
|
||||
|
||||
@@ -227,8 +235,8 @@ class MemoryGraph:
|
||||
)
|
||||
|
||||
entities = []
|
||||
if extracted_entities["tool_calls"]:
|
||||
entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
||||
if extracted_entities.get("tool_calls"):
|
||||
entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])
|
||||
|
||||
entities = self._remove_spaces_from_entities(entities)
|
||||
logger.debug(f"Extracted entities: {entities}")
|
||||
@@ -237,32 +245,43 @@ class MemoryGraph:
|
||||
def _search_graph_db(self, node_list, filters, limit=100):
|
||||
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
||||
result_relations = []
|
||||
agent_filter = ""
|
||||
if filters.get("agent_id"):
|
||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||
|
||||
for node in node_list:
|
||||
n_embedding = self.embedding_model.embed(node)
|
||||
|
||||
cypher_query = f"""
|
||||
MATCH (n {self.node_label})
|
||||
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||
{agent_filter}
|
||||
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
|
||||
WHERE similarity >= $threshold
|
||||
CALL (n) {{
|
||||
MATCH (n)-[r]->(m)
|
||||
CALL {{
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")}
|
||||
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
|
||||
UNION
|
||||
MATCH (m)-[r]->(n)
|
||||
MATCH (m)-[r]->(n)
|
||||
WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")}
|
||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
|
||||
}}
|
||||
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
|
||||
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
if filters.get("agent_id"):
|
||||
params["agent_id"] = filters["agent_id"]
|
||||
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
@@ -271,7 +290,13 @@ class MemoryGraph:
|
||||
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
||||
"""Get the entities to be deleted from the search output."""
|
||||
search_output_string = format_entities(search_output)
|
||||
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
|
||||
|
||||
# Compose user identification string for prompt
|
||||
user_identity = f"user_id: {filters['user_id']}"
|
||||
if filters.get("agent_id"):
|
||||
user_identity += f", agent_id: {filters['agent_id']}"
|
||||
|
||||
system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity)
|
||||
|
||||
_tools = [DELETE_MEMORY_TOOL_GRAPH]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
@@ -288,44 +313,59 @@ class MemoryGraph:
|
||||
)
|
||||
|
||||
to_be_deleted = []
|
||||
for item in memory_updates["tool_calls"]:
|
||||
if item["name"] == "delete_graph_memory":
|
||||
to_be_deleted.append(item["arguments"])
|
||||
# in case if it is not in the correct format
|
||||
for item in memory_updates.get("tool_calls", []):
|
||||
if item.get("name") == "delete_graph_memory":
|
||||
to_be_deleted.append(item.get("arguments"))
|
||||
# Clean entities formatting
|
||||
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||
return to_be_deleted
|
||||
|
||||
def _delete_entities(self, to_be_deleted, user_id):
|
||||
def _delete_entities(self, to_be_deleted, filters):
|
||||
"""Delete the entities from the graph."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
|
||||
for item in to_be_deleted:
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
relationship = item["relationship"]
|
||||
|
||||
# Build the agent filter for the query
|
||||
agent_filter = ""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
if agent_id:
|
||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
# Delete the specific relationship between nodes
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
||||
-[r:{relationship}]->
|
||||
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
||||
WHERE 1=1 {agent_filter}
|
||||
DELETE r
|
||||
RETURN
|
||||
n.name AS source,
|
||||
m.name AS target,
|
||||
type(r) AS relationship
|
||||
"""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
||||
def _add_entities(self, to_be_added, filters, entity_type_map):
|
||||
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
for item in to_be_added:
|
||||
# entities
|
||||
@@ -346,65 +386,80 @@ class MemoryGraph:
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
|
||||
# search for the nodes with the closest embeddings
|
||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
||||
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
|
||||
|
||||
# TODO: Create a cypher query and common params for all the cases
|
||||
if not destination_node_search_result and source_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source)
|
||||
WHERE elementId(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
destination.created = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
# Build destination MERGE properties
|
||||
merge_props = ["name: $destination_name", "user_id: $user_id"]
|
||||
if agent_id:
|
||||
merge_props.append("agent_id: $agent_id")
|
||||
merge_props_str = ", ".join(merge_props)
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source)
|
||||
WHERE elementId(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{{merge_props_str}}})
|
||||
ON CREATE SET
|
||||
destination.created = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_name": destination,
|
||||
"destination_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
elif destination_node_search_result and not source_node_search_result:
|
||||
# Build source MERGE properties
|
||||
merge_props = ["name: $source_name", "user_id: $user_id"]
|
||||
if agent_id:
|
||||
merge_props.append("agent_id: $agent_id")
|
||||
merge_props_str = ", ".join(merge_props)
|
||||
|
||||
cypher = f"""
|
||||
MATCH (destination)
|
||||
WHERE elementId(destination) = $destination_id
|
||||
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH destination
|
||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
source.created = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
MATCH (destination)
|
||||
WHERE elementId(destination) = $destination_id
|
||||
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH destination
|
||||
MERGE (source {source_label} {{{merge_props_str}}})
|
||||
ON CREATE SET
|
||||
source.created = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET
|
||||
r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||
@@ -412,53 +467,68 @@ class MemoryGraph:
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
elif source_node_search_result and destination_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source)
|
||||
WHERE elementId(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MATCH (destination)
|
||||
WHERE elementId(destination) = $destination_id
|
||||
SET destination.mentions = coalesce(destination.mentions) + 1
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created_at = timestamp(),
|
||||
r.updated_at = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
||||
|
||||
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
MATCH (source)
|
||||
WHERE elementId(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MATCH (destination)
|
||||
WHERE elementId(destination) = $destination_id
|
||||
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
MERGE (source)-[r:{relationship}]->(destination)
|
||||
ON CREATE SET
|
||||
r.created_at = timestamp(),
|
||||
r.updated_at = timestamp(),
|
||||
r.mentions = 1
|
||||
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
else:
|
||||
# Build dynamic MERGE props for both source and destination
|
||||
source_props = ["name: $source_name", "user_id: $user_id"]
|
||||
dest_props = ["name: $dest_name", "user_id: $user_id"]
|
||||
if agent_id:
|
||||
source_props.append("agent_id: $agent_id")
|
||||
dest_props.append("agent_id: $agent_id")
|
||||
source_props_str = ", ".join(source_props)
|
||||
dest_props_str = ", ".join(dest_props)
|
||||
|
||||
cypher = f"""
|
||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET source.created = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}})
|
||||
ON CREATE SET destination.created = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[rel:{relationship}]->(destination)
|
||||
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
||||
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
|
||||
"""
|
||||
MERGE (source {source_label} {{{source_props_str}}})
|
||||
ON CREATE SET source.created = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{{dest_props_str}}})
|
||||
ON CREATE SET destination.created = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $dest_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[rel:{relationship}]->(destination)
|
||||
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
||||
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
@@ -466,6 +536,8 @@ class MemoryGraph:
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
return results
|
||||
@@ -477,11 +549,16 @@ class MemoryGraph:
|
||||
item["destination"] = item["destination"].lower().replace(" ", "_")
|
||||
return entity_list
|
||||
|
||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||
def _search_source_node(self, source_embedding, filters, threshold=0.9):
|
||||
agent_filter = ""
|
||||
if filters.get("agent_id"):
|
||||
agent_filter = "AND source_candidate.agent_id = $agent_id"
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source_candidate {self.node_label})
|
||||
WHERE source_candidate.embedding IS NOT NULL
|
||||
AND source_candidate.user_id = $user_id
|
||||
{agent_filter}
|
||||
|
||||
WITH source_candidate,
|
||||
round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility
|
||||
@@ -496,18 +573,26 @@ class MemoryGraph:
|
||||
|
||||
params = {
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
"user_id": filters["user_id"],
|
||||
"threshold": threshold,
|
||||
}
|
||||
if filters.get("agent_id"):
|
||||
params["agent_id"] = filters["agent_id"]
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||
|
||||
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
||||
agent_filter = ""
|
||||
if filters.get("agent_id"):
|
||||
agent_filter = "AND destination_candidate.agent_id = $agent_id"
|
||||
|
||||
cypher = f"""
|
||||
MATCH (destination_candidate {self.node_label})
|
||||
WHERE destination_candidate.embedding IS NOT NULL
|
||||
AND destination_candidate.user_id = $user_id
|
||||
{agent_filter}
|
||||
|
||||
WITH destination_candidate,
|
||||
round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility
|
||||
@@ -520,11 +605,14 @@ class MemoryGraph:
|
||||
|
||||
RETURN elementId(destination_candidate)
|
||||
"""
|
||||
|
||||
params = {
|
||||
"destination_embedding": destination_embedding,
|
||||
"user_id": user_id,
|
||||
"user_id": filters["user_id"],
|
||||
"threshold": threshold,
|
||||
}
|
||||
if filters.get("agent_id"):
|
||||
params["agent_id"] = filters["agent_id"]
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
@@ -118,11 +118,19 @@ class MemoryGraph:
|
||||
return search_results
|
||||
|
||||
def delete_all(self, filters):
|
||||
cypher = """
|
||||
MATCH (n {user_id: $user_id})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
"""Delete all nodes and relationships for a user or specific agent."""
|
||||
if filters.get("agent_id"):
|
||||
cypher = """
|
||||
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
|
||||
else:
|
||||
cypher = """
|
||||
MATCH (n:Entity {user_id: $user_id})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
def get_all(self, filters, limit=100):
|
||||
@@ -131,20 +139,31 @@ class MemoryGraph:
|
||||
|
||||
Args:
|
||||
filters (dict): A dictionary containing filters to be applied during the retrieval.
|
||||
Supports 'user_id' (required) and 'agent_id' (optional).
|
||||
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
|
||||
Returns:
|
||||
list: A list of dictionaries, each containing:
|
||||
- 'contexts': The base data store response for each memory.
|
||||
- 'entities': A list of strings representing the nodes and relationships
|
||||
- 'source': The source node name.
|
||||
- 'relationship': The relationship type.
|
||||
- 'target': The target node name.
|
||||
"""
|
||||
|
||||
# return all nodes and relationships
|
||||
query = """
|
||||
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
|
||||
# Build query based on whether agent_id is provided
|
||||
if filters.get("agent_id"):
|
||||
query = """
|
||||
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity {user_id: $user_id, agent_id: $agent_id})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit}
|
||||
else:
|
||||
query = """
|
||||
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
params = {"user_id": filters["user_id"], "limit": limit}
|
||||
|
||||
results = self.graph.query(query, params=params)
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
@@ -241,33 +260,65 @@ class MemoryGraph:
|
||||
for node in node_list:
|
||||
n_embedding = self.embedding_model.embed(node)
|
||||
|
||||
cypher_query = """
|
||||
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||
YIELD node1, node2, similarity
|
||||
WITH node1, node2, similarity, r
|
||||
WHERE similarity >= $threshold
|
||||
RETURN node1.user_id AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.user_id AS destination, id(node2) AS destination_id, similarity
|
||||
UNION
|
||||
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||
YIELD node1, node2, similarity
|
||||
WITH node1, node2, similarity, r
|
||||
WHERE similarity >= $threshold
|
||||
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $limit;
|
||||
"""
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
# Build query based on whether agent_id is provided
|
||||
if filters.get("agent_id"):
|
||||
cypher_query = """
|
||||
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||
YIELD node1, node2, similarity
|
||||
WITH node1, node2, similarity, r
|
||||
WHERE similarity >= $threshold
|
||||
RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
|
||||
UNION
|
||||
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})<-[r]-(m:Entity)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||
YIELD node1, node2, similarity
|
||||
WITH node1, node2, similarity, r
|
||||
WHERE similarity >= $threshold
|
||||
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $limit;
|
||||
"""
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
"agent_id": filters["agent_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
else:
|
||||
cypher_query = """
|
||||
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||
YIELD node1, node2, similarity
|
||||
WITH node1, node2, similarity, r
|
||||
WHERE similarity >= $threshold
|
||||
RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
|
||||
UNION
|
||||
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
|
||||
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
|
||||
YIELD node1, node2, similarity
|
||||
WITH node1, node2, similarity, r
|
||||
WHERE similarity >= $threshold
|
||||
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $limit;
|
||||
"""
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
@@ -300,38 +351,54 @@ class MemoryGraph:
|
||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||
return to_be_deleted
|
||||
|
||||
def _delete_entities(self, to_be_deleted, user_id):
|
||||
def _delete_entities(self, to_be_deleted, filters):
|
||||
"""Delete the entities from the graph."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
|
||||
for item in to_be_deleted:
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
relationship = item["relationship"]
|
||||
|
||||
# Build the agent filter for the query
|
||||
agent_filter = ""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
if agent_id:
|
||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
# Delete the specific relationship between nodes
|
||||
cypher = f"""
|
||||
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
|
||||
-[r:{relationship}]->
|
||||
(m {{name: $dest_name, user_id: $user_id}})
|
||||
(m:Entity {{name: $dest_name, user_id: $user_id}})
|
||||
WHERE 1=1 {agent_filter}
|
||||
DELETE r
|
||||
RETURN
|
||||
n.name AS source,
|
||||
m.name AS target,
|
||||
type(r) AS relationship
|
||||
"""
|
||||
params = {
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
# added Entity label to all nodes for vector search to work
|
||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
||||
def _add_entities(self, to_be_added, filters, entity_type_map):
|
||||
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
results = []
|
||||
|
||||
for item in to_be_added:
|
||||
# entities
|
||||
source = item["source"]
|
||||
@@ -346,18 +413,21 @@ class MemoryGraph:
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
|
||||
# search for the nodes with the closest embeddings; this is basically
|
||||
# comparison of one embedding to all embeddings in a graph -> vector
|
||||
# search with cosine similarity metric
|
||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
||||
# search for the nodes with the closest embeddings
|
||||
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
|
||||
|
||||
# Prepare agent_id for node creation
|
||||
agent_id_clause = ""
|
||||
if agent_id:
|
||||
agent_id_clause = ", agent_id: $agent_id"
|
||||
|
||||
# TODO: Create a cypher query and common params for all the cases
|
||||
if not destination_node_search_result and source_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source:Entity)
|
||||
WHERE id(source) = $source_id
|
||||
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id}})
|
||||
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}})
|
||||
ON CREATE SET
|
||||
destination.created = timestamp(),
|
||||
destination.embedding = $destination_embedding,
|
||||
@@ -374,11 +444,14 @@ class MemoryGraph:
|
||||
"destination_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
elif destination_node_search_result and not source_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (destination:Entity)
|
||||
WHERE id(destination) = $destination_id
|
||||
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
|
||||
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
||||
ON CREATE SET
|
||||
source.created = timestamp(),
|
||||
source.embedding = $source_embedding,
|
||||
@@ -395,6 +468,9 @@ class MemoryGraph:
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
elif source_node_search_result and destination_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source:Entity)
|
||||
@@ -412,12 +488,15 @@ class MemoryGraph:
|
||||
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
else:
|
||||
cypher = f"""
|
||||
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
|
||||
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
||||
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
|
||||
ON MATCH SET n.embedding = $source_embedding
|
||||
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id}})
|
||||
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}})
|
||||
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
|
||||
ON MATCH SET m.embedding = $dest_embedding
|
||||
MERGE (n)-[rel:{relationship}]->(m)
|
||||
@@ -431,6 +510,9 @@ class MemoryGraph:
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": user_id,
|
||||
}
|
||||
if agent_id:
|
||||
params["agent_id"] = agent_id
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
return results
|
||||
@@ -442,37 +524,80 @@ class MemoryGraph:
|
||||
item["destination"] = item["destination"].lower().replace(" ", "_")
|
||||
return entity_list
|
||||
|
||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||
YIELD distance, node, similarity
|
||||
WITH node AS source_candidate, similarity
|
||||
WHERE source_candidate.user_id = $user_id AND similarity >= $threshold
|
||||
RETURN id(source_candidate);
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
def _search_source_node(self, source_embedding, filters, threshold=0.9):
|
||||
"""Search for source nodes with similar embeddings."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
|
||||
if agent_id:
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||
YIELD distance, node, similarity
|
||||
WITH node AS source_candidate, similarity
|
||||
WHERE source_candidate.user_id = $user_id
|
||||
AND source_candidate.agent_id = $agent_id
|
||||
AND similarity >= $threshold
|
||||
RETURN id(source_candidate);
|
||||
"""
|
||||
params = {
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
else:
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||
YIELD distance, node, similarity
|
||||
WITH node AS source_candidate, similarity
|
||||
WHERE source_candidate.user_id = $user_id
|
||||
AND similarity >= $threshold
|
||||
RETURN id(source_candidate);
|
||||
"""
|
||||
params = {
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||
YIELD distance, node, similarity
|
||||
WITH node AS destination_candidate, similarity
|
||||
WHERE node.user_id = $user_id AND similarity >= $threshold
|
||||
RETURN id(destination_candidate);
|
||||
"""
|
||||
params = {
|
||||
"destination_embedding": destination_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
||||
"""Search for destination nodes with similar embeddings."""
|
||||
user_id = filters["user_id"]
|
||||
agent_id = filters.get("agent_id", None)
|
||||
|
||||
if agent_id:
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||
YIELD distance, node, similarity
|
||||
WITH node AS destination_candidate, similarity
|
||||
WHERE node.user_id = $user_id
|
||||
AND node.agent_id = $agent_id
|
||||
AND similarity >= $threshold
|
||||
RETURN id(destination_candidate);
|
||||
"""
|
||||
params = {
|
||||
"destination_embedding": destination_embedding,
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
else:
|
||||
cypher = """
|
||||
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||
YIELD distance, node, similarity
|
||||
WITH node AS destination_candidate, similarity
|
||||
WHERE node.user_id = $user_id
|
||||
AND similarity >= $threshold
|
||||
RETURN id(destination_candidate);
|
||||
"""
|
||||
params = {
|
||||
"destination_embedding": destination_embedding,
|
||||
"user_id": user_id,
|
||||
"threshold": threshold,
|
||||
}
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user