Fix: Add Google Genai library support (#2941)

This commit is contained in:
Akshat Jain
2025-06-17 17:47:09 +05:30
committed by GitHub
parent e0003247c3
commit c70dc7614b
7 changed files with 589 additions and 276 deletions

View File

@@ -118,11 +118,19 @@ class MemoryGraph:
return search_results
def delete_all(self, filters):
cypher = """
MATCH (n {user_id: $user_id})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
"""Delete all nodes and relationships for a user or specific agent."""
if filters.get("agent_id"):
cypher = """
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
else:
cypher = """
MATCH (n:Entity {user_id: $user_id})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100):
@@ -131,20 +139,31 @@ class MemoryGraph:
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
Supports 'user_id' (required) and 'agent_id' (optional).
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
list: A list of dictionaries, each containing:
- 'contexts': The base data store response for each memory.
- 'entities': A list of strings representing the nodes and relationships
- 'source': The source node name.
- 'relationship': The relationship type.
- 'target': The target node name.
"""
# return all nodes and relationships
query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
# Build query based on whether agent_id is provided
if filters.get("agent_id"):
query = """
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity {user_id: $user_id, agent_id: $agent_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit}
else:
query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
params = {"user_id": filters["user_id"], "limit": limit}
results = self.graph.query(query, params=params)
final_results = []
for result in results:
@@ -241,33 +260,65 @@ class MemoryGraph:
for node in node_list:
n_embedding = self.embedding_model.embed(node)
cypher_query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node1.user_id AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.user_id AS destination, id(node2) AS destination_id, similarity
UNION
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit;
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
# Build query based on whether agent_id is provided
if filters.get("agent_id"):
cypher_query = """
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
UNION
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})<-[r]-(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit;
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"agent_id": filters["agent_id"],
"limit": limit,
}
else:
cypher_query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
UNION
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit;
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
@@ -300,38 +351,54 @@ class MemoryGraph:
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, user_id):
def _delete_entities(self, to_be_deleted, filters):
"""Delete the entities from the graph."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# Build the agent filter for the query
agent_filter = ""
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
if agent_id:
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
params["agent_id"] = agent_id
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
-[r:{relationship}]->
(m {{name: $dest_name, user_id: $user_id}})
(m:Entity {{name: $dest_name, user_id: $user_id}})
WHERE 1=1 {agent_filter}
DELETE r
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
"""
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
result = self.graph.query(cypher, params=params)
results.append(result)
return results
# added Entity label to all nodes for vector search to work
def _add_entities(self, to_be_added, user_id, entity_type_map):
def _add_entities(self, to_be_added, filters, entity_type_map):
"""Add the new entities to the graph. Merge the nodes if they already exist."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
results = []
for item in to_be_added:
# entities
source = item["source"]
@@ -346,18 +413,21 @@ class MemoryGraph:
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# search for the nodes with the closest embeddings; this is basically
# comparison of one embedding to all embeddings in a graph -> vector
# search with cosine similarity metric
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
# search for the nodes with the closest embeddings
source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
# Prepare agent_id for node creation
agent_id_clause = ""
if agent_id:
agent_id_clause = ", agent_id: $agent_id"
# TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
cypher = f"""
MATCH (source:Entity)
WHERE id(source) = $source_id
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id}})
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET
destination.created = timestamp(),
destination.embedding = $destination_embedding,
@@ -374,11 +444,14 @@ class MemoryGraph:
"destination_embedding": dest_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
elif destination_node_search_result and not source_node_search_result:
cypher = f"""
MATCH (destination:Entity)
WHERE id(destination) = $destination_id
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET
source.created = timestamp(),
source.embedding = $source_embedding,
@@ -395,6 +468,9 @@ class MemoryGraph:
"source_embedding": source_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
elif source_node_search_result and destination_node_search_result:
cypher = f"""
MATCH (source:Entity)
@@ -412,12 +488,15 @@ class MemoryGraph:
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
else:
cypher = f"""
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id}})
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relationship}]->(m)
@@ -431,6 +510,9 @@ class MemoryGraph:
"dest_embedding": dest_embedding,
"user_id": user_id,
}
if agent_id:
params["agent_id"] = agent_id
result = self.graph.query(cypher, params=params)
results.append(result)
return results
@@ -442,37 +524,80 @@ class MemoryGraph:
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher = """
CALL vector_search.search("memzero", 1, $source_embedding)
YIELD distance, node, similarity
WITH node AS source_candidate, similarity
WHERE source_candidate.user_id = $user_id AND similarity >= $threshold
RETURN id(source_candidate);
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"threshold": threshold,
}
def _search_source_node(self, source_embedding, filters, threshold=0.9):
"""Search for source nodes with similar embeddings."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
if agent_id:
cypher = """
CALL vector_search.search("memzero", 1, $source_embedding)
YIELD distance, node, similarity
WITH node AS source_candidate, similarity
WHERE source_candidate.user_id = $user_id
AND source_candidate.agent_id = $agent_id
AND similarity >= $threshold
RETURN id(source_candidate);
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"agent_id": agent_id,
"threshold": threshold,
}
else:
cypher = """
CALL vector_search.search("memzero", 1, $source_embedding)
YIELD distance, node, similarity
WITH node AS source_candidate, similarity
WHERE source_candidate.user_id = $user_id
AND similarity >= $threshold
RETURN id(source_candidate);
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding)
YIELD distance, node, similarity
WITH node AS destination_candidate, similarity
WHERE node.user_id = $user_id AND similarity >= $threshold
RETURN id(destination_candidate);
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"threshold": threshold,
}
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
"""Search for destination nodes with similar embeddings."""
user_id = filters["user_id"]
agent_id = filters.get("agent_id", None)
if agent_id:
cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding)
YIELD distance, node, similarity
WITH node AS destination_candidate, similarity
WHERE node.user_id = $user_id
AND node.agent_id = $agent_id
AND similarity >= $threshold
RETURN id(destination_candidate);
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"agent_id": agent_id,
"threshold": threshold,
}
else:
cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding)
YIELD distance, node, similarity
WITH node AS destination_candidate, similarity
WHERE node.user_id = $user_id
AND similarity >= $threshold
RETURN id(destination_candidate);
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result