Add Amazon Neptune Analytics graph_store configuration & integration (#2949)

This commit is contained in:
Andrew Carbonetto
2025-07-04 16:26:21 -07:00
committed by GitHub
parent 7484eed4b2
commit 05c404d8d3
12 changed files with 1823 additions and 5 deletions

View File

@@ -13,7 +13,7 @@ install:
install_all:
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j langchain-aws rank-bm25 pymochow pymongo
# Format code with ruff
format:

View File

@@ -232,6 +232,66 @@ m = Memory.from_config(config_dict=config)
```
</CodeGroup>
### Initialize Neptune Analytics
Mem0 now supports Amazon Neptune Analytics as a graph store provider. This integration allows you to use Neptune Analytics for storing and querying graph-based memories.
#### Instance Setup
Create an Amazon Neptune Analytics instance in your AWS account following the [AWS documentation](https://docs.aws.amazon.com/neptune-analytics/latest/userguide/get-started.html).
- Public connectivity is not enabled by default; if you access the instance from outside a VPC, it must be enabled.
- Once the Amazon Neptune Analytics instance is available, note its graph identifier; you will need it to connect.
- The Neptune Analytics instance must be created with the same vector dimensions as your embedding model produces (a provisioning sketch follows this list). See: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/vector-index.html
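If you prefer to provision the graph programmatically, below is a minimal sketch using the boto3 `neptune-graph` client. The graph name, provisioned memory, and region are placeholder values, and the vector dimension must match your embedder (1536 in the examples on this page).
```python
import boto3

# Sketch only: provision a Neptune Analytics graph whose vector index dimension
# matches the embedding model (text-embedding-3-large with 1536 dims below).
client = boto3.client("neptune-graph", region_name="us-east-1")  # region is a placeholder

response = client.create_graph(
    graphName="mem0-graph-memory",                  # placeholder name
    provisionedMemory=16,                           # m-NCUs; size this to your workload
    publicConnectivity=True,                        # needed when connecting from outside a VPC
    vectorSearchConfiguration={"dimension": 1536},  # must match the embedder dimensions
)
graph_id = response["id"]

# Wait until the graph is ready before pointing Mem0 at it.
waiter = client.get_waiter("graph_available")
waiter.wait(graphIdentifier=graph_id)
print(f"Use endpoint: neptune-graph://{graph_id}")
```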
#### Attach Credentials
Configure your AWS credentials with access to your Amazon Neptune Analytics resources by following the [Configuration and credentials precedence](https://docs.aws.amazon.com/cli/v1/userguide/cli-chap-configure.html#configure-precedence).
- For example, export your AWS access key, secret access key, and session token via environment variables:
```bash
export AWS_ACCESS_KEY_ID=your-access-key
export AWS_SECRET_ACCESS_KEY=your-secret-key
export AWS_SESSION_TOKEN=your-session-token
export AWS_DEFAULT_REGION=your-region
```
- The IAM user or role making the request must have a policy attached that allows the following IAM actions on that neptune-graph resource (a quick credential check is sketched after this list):
- neptune-graph:ReadDataViaQuery
- neptune-graph:WriteDataViaQuery
- neptune-graph:DeleteDataViaQuery
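Before initializing Mem0, you can sanity-check that the credential chain resolves and that the graph is reachable. This is an illustrative sketch using standard boto3 calls; `<GRAPH_ID>` is a placeholder for your graph identifier.
```python
import boto3

# Confirm which identity the default AWS credential chain resolves to.
identity = boto3.client("sts").get_caller_identity()
print("Caller:", identity["Arn"])

# Confirm the Neptune Analytics graph is reachable with those credentials.
# <GRAPH_ID> is a placeholder for your graph identifier (e.g. g-abc123).
graph = boto3.client("neptune-graph").get_graph(graphIdentifier="<GRAPH_ID>")
print("Graph status:", graph["status"])  # expect "AVAILABLE"
```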
#### Usage
The Neptune memory store uses the LangChain AWS Python API to connect to Neptune instances. For additional configuration options for connecting to your Amazon Neptune Analytics instance, see the [LangChain AWS API documentation](https://python.langchain.com/api_reference/aws/graphs/langchain_aws.graphs.neptune_graph.NeptuneAnalyticsGraph.html).
<CodeGroup>
```python Python
from mem0 import Memory

# This example must connect to a neptune-graph instance with 1536 vector dimensions specified.
config = {
    "embedder": {
        "provider": "openai",
        "config": {"model": "text-embedding-3-large", "embedding_dims": 1536},
    },
    "graph_store": {
        "provider": "neptune",
        "config": {
            "endpoint": "neptune-graph://<GRAPH_ID>",
        },
    },
}

m = Memory.from_config(config_dict=config)
```
</CodeGroup>
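Once the memory instance is created, graph memories are added and searched with the regular Mem0 API. The snippet below mirrors the example notebook; the user id and query text are illustrative and should be adapted to your application.
```python
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
]

# Relationships inferred from the messages are written to Neptune Analytics.
m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"})

# Search returns vector-store matches plus related graph relationships.
results = m.search("what does alice plan to do tonight?", user_id="alice")
for item in results["results"]:
    print(item["memory"], item["score"])
for relation in results["relations"]:
    print(relation)
```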
#### Troubleshooting
- For issues connecting to Amazon Neptune Analytics, please refer to the [Connecting to a graph guide](https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted-connecting.html).
- For issues related to authentication, refer to the [boto3 client configuration options](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html).
- For more details on how to connect to, configure, and use the Neptune graph memory store, see the [Neptune Analytics example notebook](examples/graph-db-demo/neptune-analytics-example.ipynb). A debug-logging snippet from that notebook is shown below.
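When debugging connectivity or query issues, it can help to surface the OpenCypher queries that Mem0 issues against Neptune Analytics. The snippet below raises the log level of the Neptune graph-memory modules, as the example notebook does:
```python
import logging
import sys

# Print the OpenCypher queries issued against Neptune Analytics to stdout.
logging.getLogger("mem0.graphs.neptune.main").setLevel(logging.DEBUG)
logging.getLogger("mem0.graphs.neptune.base").setLevel(logging.DEBUG)
logging.basicConfig(format="%(levelname)s - %(message)s", stream=sys.stdout)
```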
## Graph Operations
Mem0's graph store supports the following operations:

View File

@@ -0,0 +1,593 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neptune as Graph Memory\n",
"\n",
"In this notebook, we will be connecting using a Amazon Neptune Analytics instance as our memory graph storage for Mem0.\n",
"\n",
"The Graph Memory storage persists memories in a graph or relationship form when performing `m.add` memory operations. It then uses vector distance algorithms to find related memories during a `m.search` operation. Relationships are returned in the result, and add context to the memories.\n",
"\n",
"Reference: [Vector Similarity using Neptune Analytics](https://docs.aws.amazon.com/neptune-analytics/latest/userguide/vector-similarity.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisites\n",
"\n",
"### 1. Install Mem0 with Graph Memory support \n",
"\n",
"To use Mem0 with Graph Memory support, install it using pip:\n",
"\n",
"```bash\n",
"pip install \"mem0ai[graph]\"\n",
"```\n",
"\n",
"This command installs Mem0 along with the necessary dependencies for graph functionality.\n",
"\n",
"### 2. Connect to Neptune\n",
"\n",
"To connect to Amazon Neptune Analytics, you need to configure Neptune with your Amazon profile credentials. The best way to do this is to declare environment variables with IAM permission to your Neptune Analytics instance. The `graph-identifier` for the instance to persist memories needs to be defined in the Mem0 configuration under `\"graph_store\"`, with the `\"neptune\"` provider. Note that the Neptune Analytics instance needs to have `vector-search-configuration` defined to meet the needs of the llm model's vector dimensions, see: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/vector-index.html.\n",
"\n",
"```python\n",
"embedding_dimensions = 1536\n",
"graph_identifier = \"<MY-GRAPH>\" # graph with 1536 dimensions for vector search\n",
"config = {\n",
" \"embedder\": {\n",
" \"provider\": \"openai\",\n",
" \"config\": {\n",
" \"model\": \"text-embedding-3-large\",\n",
" \"embedding_dims\": embedding_dimensions\n",
" },\n",
" },\n",
" \"graph_store\": {\n",
" \"provider\": \"neptune\",\n",
" \"config\": {\n",
" \"endpoint\": f\"neptune-graph://{graph_identifier}\",\n",
" },\n",
" },\n",
"}\n",
"```\n",
"\n",
"### 3. Configure OpenSearch\n",
"\n",
"We're going to use OpenSearch as our vector store. You can run [OpenSearch from docker image](https://docs.opensearch.org/docs/latest/install-and-configure/install-opensearch/docker/):\n",
"\n",
"```bash\n",
"docker pull opensearchproject/opensearch:2\n",
"```\n",
"\n",
"And verify that it's running with a `<custom-admin-password>`:\n",
"\n",
"```bash\n",
" docker run -d -p 9200:9200 -p 9600:9600 -e \"discovery.type=single-node\" -e \"OPENSEARCH_INITIAL_ADMIN_PASSWORD=<custom-admin-password>\" opensearchproject/opensearch:latest\n",
"\n",
" curl https://localhost:9200 -ku admin:<custom-admin-password>\n",
"```\n",
"\n",
"We're going to connect [OpenSearch using the python client](https://github.com/opensearch-project/opensearch-py):\n",
"\n",
"```bash\n",
"pip install \"opensearch-py\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration\n",
"\n",
"Do all the imports and configure OpenAI (enter your OpenAI API key):"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T20:52:48.330121Z",
"start_time": "2025-07-03T20:52:47.092369Z"
}
},
"source": [
"from mem0 import Memory\n",
"import os\n",
"import logging\n",
"import sys\n",
"\n",
"logging.getLogger(\"mem0.graphs.neptune.main\").setLevel(logging.DEBUG)\n",
"logging.getLogger(\"mem0.graphs.neptune.base\").setLevel(logging.DEBUG)\n",
"logger = logging.getLogger(__name__)\n",
"logger.setLevel(logging.DEBUG)\n",
"\n",
"logging.basicConfig(\n",
" format=\"%(levelname)s - %(message)s\",\n",
" datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
" stream=sys.stdout, # Explicitly set output to stdout\n",
")"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Setup the Mem0 configuration using:\n",
"- openai as the embedder\n",
"- Amazon Neptune Analytics instance as a graph store\n",
"- OpenSearch as the vector store"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T20:52:50.958741Z",
"start_time": "2025-07-03T20:52:50.955127Z"
}
},
"source": [
"graph_identifier = os.environ.get(\"GRAPH_ID\")\n",
"opensearch_username = os.environ.get(\"OS_USERNAME\")\n",
"opensearch_password = os.environ.get(\"OS_PASSWORD\")\n",
"config = {\n",
" \"embedder\": {\n",
" \"provider\": \"openai\",\n",
" \"config\": {\"model\": \"text-embedding-3-large\", \"embedding_dims\": 1536},\n",
" },\n",
" \"graph_store\": {\n",
" \"provider\": \"neptune\",\n",
" \"config\": {\n",
" \"endpoint\": f\"neptune-graph://{graph_identifier}\",\n",
" },\n",
" },\n",
" \"vector_store\": {\n",
" \"provider\": \"opensearch\",\n",
" \"config\": {\n",
" \"collection_name\": \"vector_store\",\n",
" \"host\": \"localhost\",\n",
" \"port\": 9200,\n",
" \"user\": opensearch_username,\n",
" \"password\": opensearch_password,\n",
" \"use_ssl\": False,\n",
" \"verify_certs\": False,\n",
" },\n",
" },\n",
"}"
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Graph Memory initializiation\n",
"\n",
"Initialize Memgraph as a Graph Memory store:"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T20:52:55.655673Z",
"start_time": "2025-07-03T20:52:54.141041Z"
}
},
"source": [
"m = Memory.from_config(config_dict=config)\n",
"\n",
"app_id = \"movies\"\n",
"user_id = \"alice\"\n",
"\n",
"m.delete_all(user_id=user_id)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING - Creating index vector_store, it might take 1-2 minutes...\n",
"WARNING - Creating index mem0migrations, it might take 1-2 minutes...\n",
"DEBUG - delete_all query=\n",
" MATCH (n {user_id: $user_id})\n",
" DETACH DELETE n\n",
" \n"
]
},
{
"data": {
"text/plain": [
"{'message': 'Memories deleted successfully!'}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Store memories\n",
"\n",
"Create memories and store one at a time:"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T20:53:05.338249Z",
"start_time": "2025-07-03T20:52:57.528210Z"
}
},
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"I'm planning to watch a movie tonight. Any recommendations?\",\n",
" },\n",
"]\n",
"\n",
"# Store inferred memories (default behavior)\n",
"result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
"\n",
"all_results = m.get_all(user_id=user_id)\n",
"for n in all_results[\"results\"]:\n",
" print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
"\n",
"for e in all_results[\"relations\"]:\n",
" print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DEBUG - Extracted entities: [{'source': 'alice', 'relationship': 'plans_to_watch', 'destination': 'movie'}]\n",
"DEBUG - _search_graph_db\n",
" query=\n",
" MATCH (n )\n",
" WHERE n.user_id = $user_id\n",
" WITH n, $n_embedding as n_embedding\n",
" CALL neptune.algo.vectors.distanceByEmbedding(\n",
" n_embedding,\n",
" n,\n",
" {metric:\"CosineSimilarity\"}\n",
" ) YIELD distance\n",
" WITH n, distance as similarity\n",
" WHERE similarity >= $threshold\n",
" CALL {\n",
" WITH n\n",
" MATCH (n)-[r]->(m) \n",
" RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id\n",
" UNION ALL\n",
" WITH n\n",
" MATCH (m)-[r]->(n) \n",
" RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id\n",
" }\n",
" WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
" RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
" ORDER BY similarity DESC\n",
" LIMIT $limit\n",
" \n",
"DEBUG - Deleted relationships: []\n",
"DEBUG - _search_source_node\n",
" query=\n",
" MATCH (source_candidate )\n",
" WHERE source_candidate.user_id = $user_id \n",
"\n",
" WITH source_candidate, $source_embedding as v_embedding\n",
" CALL neptune.algo.vectors.distanceByEmbedding(\n",
" v_embedding,\n",
" source_candidate,\n",
" {metric:\"CosineSimilarity\"}\n",
" ) YIELD distance\n",
" WITH source_candidate, distance AS cosine_similarity\n",
" WHERE cosine_similarity >= $threshold\n",
"\n",
" WITH source_candidate, cosine_similarity\n",
" ORDER BY cosine_similarity DESC\n",
" LIMIT 1\n",
"\n",
" RETURN id(source_candidate), cosine_similarity\n",
" \n",
"DEBUG - _search_destination_node\n",
" query=\n",
" MATCH (destination_candidate )\n",
" WHERE destination_candidate.user_id = $user_id\n",
" \n",
" WITH destination_candidate, $destination_embedding as v_embedding\n",
" CALL neptune.algo.vectors.distanceByEmbedding(\n",
" v_embedding,\n",
" destination_candidate, \n",
" {metric:\"CosineSimilarity\"}\n",
" ) YIELD distance\n",
" WITH destination_candidate, distance AS cosine_similarity\n",
" WHERE cosine_similarity >= $threshold\n",
"\n",
" WITH destination_candidate, cosine_similarity\n",
" ORDER BY cosine_similarity DESC\n",
" LIMIT 1\n",
" \n",
" RETURN id(destination_candidate), cosine_similarity\n",
" \n",
"DEBUG - _add_entities:\n",
" destination_node_search_result=[]\n",
" source_node_search_result=[]\n",
" query=\n",
" MERGE (n :`__User__` {name: $source_name, user_id: $user_id})\n",
" ON CREATE SET n.created = timestamp(),\n",
" n.mentions = 1\n",
" \n",
" ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1\n",
" WITH n, $source_embedding as source_embedding\n",
" CALL neptune.algo.vectors.upsert(n, source_embedding)\n",
" WITH n\n",
" MERGE (m :`entertainment` {name: $dest_name, user_id: $user_id})\n",
" ON CREATE SET m.created = timestamp(),\n",
" m.mentions = 1\n",
" \n",
" ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1\n",
" WITH n, m, $dest_embedding as dest_embedding\n",
" CALL neptune.algo.vectors.upsert(m, dest_embedding)\n",
" WITH n, m\n",
" MERGE (n)-[rel:plans_to_watch]->(m)\n",
" ON CREATE SET rel.created = timestamp(), rel.mentions = 1\n",
" ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1\n",
" RETURN n.name AS source, type(rel) AS relationship, m.name AS target\n",
" \n",
"DEBUG - Retrieved 1 relationships\n",
"node \"Planning to watch a movie tonight\": [hash: bf55418607cfdca4afa311b5fd8496bd]\n",
"edge \"alice\" --plans_to_watch--> \"movie\"\n"
]
}
],
"execution_count": 4
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T20:53:17.755933Z",
"start_time": "2025-07-03T20:53:11.568772Z"
}
},
"source": [
"messages = [\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"How about a thriller movies? They can be quite engaging.\",\n",
" },\n",
"]\n",
"\n",
"# Store inferred memories (default behavior)\n",
"result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
"\n",
"all_results = m.get_all(user_id=user_id)\n",
"for n in all_results[\"results\"]:\n",
" print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
"\n",
"for e in all_results[\"relations\"]:\n",
" print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DEBUG - Extracted entities: [{'source': 'thriller_movies', 'relationship': 'is_engaging', 'destination': 'thriller_movies'}]\n",
"DEBUG - _search_graph_db\n",
" query=\n",
" MATCH (n )\n",
" WHERE n.user_id = $user_id\n",
" WITH n, $n_embedding as n_embedding\n",
" CALL neptune.algo.vectors.distanceByEmbedding(\n",
" n_embedding,\n",
" n,\n",
" {metric:\"CosineSimilarity\"}\n",
" ) YIELD distance\n",
" WITH n, distance as similarity\n",
" WHERE similarity >= $threshold\n",
" CALL {\n",
" WITH n\n",
" MATCH (n)-[r]->(m) \n",
" RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id\n",
" UNION ALL\n",
" WITH n\n",
" MATCH (m)-[r]->(n) \n",
" RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id\n",
" }\n",
" WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
" RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
" ORDER BY similarity DESC\n",
" LIMIT $limit\n",
" \n",
"DEBUG - Deleted relationships: []\n",
"DEBUG - _search_source_node\n",
" query=\n",
" MATCH (source_candidate )\n",
" WHERE source_candidate.user_id = $user_id \n",
"\n",
" WITH source_candidate, $source_embedding as v_embedding\n",
" CALL neptune.algo.vectors.distanceByEmbedding(\n",
" v_embedding,\n",
" source_candidate,\n",
" {metric:\"CosineSimilarity\"}\n",
" ) YIELD distance\n",
" WITH source_candidate, distance AS cosine_similarity\n",
" WHERE cosine_similarity >= $threshold\n",
"\n",
" WITH source_candidate, cosine_similarity\n",
" ORDER BY cosine_similarity DESC\n",
" LIMIT 1\n",
"\n",
" RETURN id(source_candidate), cosine_similarity\n",
" \n",
"DEBUG - _search_destination_node\n",
" query=\n",
" MATCH (destination_candidate )\n",
" WHERE destination_candidate.user_id = $user_id\n",
" \n",
" WITH destination_candidate, $destination_embedding as v_embedding\n",
" CALL neptune.algo.vectors.distanceByEmbedding(\n",
" v_embedding,\n",
" destination_candidate, \n",
" {metric:\"CosineSimilarity\"}\n",
" ) YIELD distance\n",
" WITH destination_candidate, distance AS cosine_similarity\n",
" WHERE cosine_similarity >= $threshold\n",
"\n",
" WITH destination_candidate, cosine_similarity\n",
" ORDER BY cosine_similarity DESC\n",
" LIMIT 1\n",
" \n",
" RETURN id(destination_candidate), cosine_similarity\n",
" \n",
"DEBUG - _add_entities:\n",
" destination_node_search_result=[{'id(destination_candidate)': '67c49d52-e305-47fe-9fce-2cd5adc5d83c0', 'cosine_similarity': 0.999999}]\n",
" source_node_search_result=[{'id(source_candidate)': '67c49d52-e305-47fe-9fce-2cd5adc5d83c0', 'cosine_similarity': 0.999999}]\n",
" query=\n",
" MATCH (source)\n",
" WHERE id(source) = $source_id\n",
" SET source.mentions = coalesce(source.mentions, 0) + 1\n",
" WITH source\n",
" MATCH (destination)\n",
" WHERE id(destination) = $destination_id\n",
" SET destination.mentions = coalesce(destination.mentions) + 1\n",
" MERGE (source)-[r:is_engaging]->(destination)\n",
" ON CREATE SET \n",
" r.created_at = timestamp(),\n",
" r.updated_at = timestamp(),\n",
" r.mentions = 1\n",
" ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1\n",
" RETURN source.name AS source, type(r) AS relationship, destination.name AS target\n",
" \n",
"DEBUG - Retrieved 3 relationships\n",
"node \"Planning to watch a movie tonight\": [hash: bf55418607cfdca4afa311b5fd8496bd]\n",
"edge \"thriller_movies\" --is_a_type_of--> \"movie\"\n",
"edge \"alice\" --plans_to_watch--> \"movie\"\n",
"edge \"thriller_movies\" --is_engaging--> \"thriller_movies\"\n"
]
}
],
"execution_count": 6
},
{
"cell_type": "code",
"metadata": {
"jupyter": {
"is_executing": true
},
"ExecuteTime": {
"start_time": "2025-07-03T20:53:17.775656Z"
}
},
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"I'm not a big fan of thriller movies but I love sci-fi movies.\",\n",
" },\n",
"]\n",
"\n",
"# Store inferred memories (default behavior)\n",
"result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
"\n",
"all_results = m.get_all(user_id=user_id)\n",
"for n in all_results[\"results\"]:\n",
" print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
"\n",
"for e in all_results[\"relations\"]:\n",
" print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"messages = [\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n",
"]\n",
"\n",
"# Store inferred memories (default behavior)\n",
"result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
"\n",
"all_results = m.get_all(user_id=user_id)\n",
"for n in all_results[\"results\"]:\n",
" print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
"\n",
"for e in all_results[\"relations\"]:\n",
" print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search memories"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"search_results = m.search(\"what does alice love?\", user_id=user_id)\n",
"for result in search_results[\"results\"]:\n",
" print(f\"\\\"{result['memory']}\\\" [score: {result['score']}]\")\n",
"for relation in search_results[\"relations\"]:\n",
" print(f\"{relation}\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"m.delete_all(\"user_id\")\n",
"m.reset()"
],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0
mem0/graphs/__init__.py Normal file
View File

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
@@ -41,9 +41,43 @@ class MemgraphConfig(BaseModel):
return values
class NeptuneConfig(BaseModel):
endpoint: Optional[str] = Field(
None,
description="Endpoint to connect to a Neptune Analytics Server as neptune-graph://<graphid>",
)
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
@model_validator(mode="before")
def check_host_port_or_path(cls, values):
endpoint = values.get("endpoint")
if not endpoint:
raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
if endpoint.startswith("neptune-db://"):
raise ValueError("neptune-db server is not yet supported")
elif endpoint.startswith("neptune-graph://"):
# This is a Neptune Analytics Graph
graph_identifier = endpoint.replace("neptune-graph://", "")
if not graph_identifier.startswith("g-"):
raise ValueError("Provide a valid 'graph_identifier'.")
values["graph_identifier"] = graph_identifier
return values
else:
raise ValueError(
"You must provide an endpoint to create a NeptuneServer as either neptune-db://<endpoint> or neptune-graph://<graphid>"
)
class GraphStoreConfig(BaseModel):
provider: str = Field(description="Provider of the data store (e.g., 'neo4j')", default="neo4j")
config: Neo4jConfig = Field(description="Configuration for the specific data store", default=None)
provider: str = Field(
description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune')",
default="neo4j",
)
config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig] = Field(
description="Configuration for the specific data store", default=None
)
llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
custom_prompt: Optional[str] = Field(
description="Custom prompt to fetch entities from the given text", default=None
@@ -56,5 +90,7 @@ class GraphStoreConfig(BaseModel):
return Neo4jConfig(**v.model_dump())
elif provider == "memgraph":
return MemgraphConfig(**v.model_dump())
elif provider == "neptune":
return NeptuneConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported graph store provider: {provider}")

View File

410
mem0/graphs/neptune/base.py Normal file
View File

@@ -0,0 +1,410 @@
import logging
from abc import ABC, abstractmethod
from mem0.memory.utils import format_entities
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
class NeptuneBase(ABC):
"""
Abstract base class for Neptune (Neptune Analytics and Neptune DB) calls that use OpenCypher
to store and retrieve data.
"""
@staticmethod
def _create_embedding_model(config):
"""
:return: the Embedder model used for memory store
"""
return EmbedderFactory.create(
config.embedder.provider,
config.embedder.config,
{"enable_embeddings": True},
)
@staticmethod
def _create_llm(config, llm_provider):
"""
:return: the llm model used for memory store
"""
return LlmFactory.create(llm_provider, config.llm.config)
def add(self, data, filters):
"""
Adds data to the graph.
Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
def _retrieve_nodes_from_data(self, data, filters):
"""
Extract all entities mentioned in the query.
"""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)
entity_type_map = {}
try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""
Establish relations among the extracted nodes.
"""
if self.config.graph_store.custom_prompt:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
),
},
{"role": "user", "content": data},
]
else:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
},
{
"role": "user",
"content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
},
]
_tools = [RELATIONS_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [RELATIONS_STRUCT_TOOL]
extracted_entities = self.llm.generate_response(
messages=messages,
tools=_tools,
)
entities = []
if extracted_entities["tool_calls"]:
entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
entities = self._remove_spaces_from_entities(entities)
logger.debug(f"Extracted entities: {entities}")
return entities
def _remove_spaces_from_entities(self, entity_list):
for item in entity_list:
item["source"] = item["source"].lower().replace(" ", "_")
item["relationship"] = item["relationship"].lower().replace(" ", "_")
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""
Get the entities to be deleted from the search output.
"""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
]
memory_updates = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
tools=_tools,
)
to_be_deleted = []
for item in memory_updates["tool_calls"]:
if item["name"] == "delete_graph_memory":
to_be_deleted.append(item["arguments"])
# in case it is not in the correct format
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, user_id):
"""
Delete the entities from the graph.
"""
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# Delete the specific relationship between nodes
cypher, params = self._delete_entities_cypher(source, destination, relationship, user_id)
result = self.graph.query(cypher, params=params)
results.append(result)
return results
@abstractmethod
def _delete_entities_cypher(self, source, destination, relationship, user_id):
"""
Returns the OpenCypher query and parameters for deleting entities in the graph DB
"""
pass
def _add_entities(self, to_be_added, user_id, entity_type_map):
"""
Add the new entities to the graph. Merge the nodes if they already exist.
"""
results = []
for item in to_be_added:
# entities
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# types
source_type = entity_type_map.get(source, "__User__")
destination_type = entity_type_map.get(destination, "__User__")
# embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# search for the nodes with the closest embeddings
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
cypher, params = self._add_entities_cypher(
source_node_search_result,
source,
source_embedding,
source_type,
destination_node_search_result,
destination,
dest_embedding,
destination_type,
relationship,
user_id,
)
result = self.graph.query(cypher, params=params)
results.append(result)
return results
@abstractmethod
def _add_entities_cypher(
self,
source_node_list,
source,
source_embedding,
source_type,
destination_node_list,
destination,
dest_embedding,
destination_type,
relationship,
user_id,
):
"""
Returns the OpenCypher query and parameters for adding entities in the graph DB
"""
pass
def search(self, query, filters, limit=100):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
list: A list of relationship dictionaries with "source", "relationship", and "destination" keys,
reranked by BM25 relevance to the query. Returns an empty list if nothing matches.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = []
for item in reranked_results:
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
return search_results
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher, params = self._search_source_node_cypher(source_embedding, user_id, threshold)
result = self.graph.query(cypher, params=params)
return result
@abstractmethod
def _search_source_node_cypher(self, source_embedding, user_id, threshold):
"""
Returns the OpenCypher query and parameters to search for source nodes
"""
pass
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher, params = self._search_destination_node_cypher(destination_embedding, user_id, threshold)
result = self.graph.query(cypher, params=params)
return result
@abstractmethod
def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
"""
Returns the OpenCypher query and parameters to search for destination nodes
"""
pass
def delete_all(self, filters):
cypher, params = self._delete_all_cypher(filters)
self.graph.query(cypher, params=params)
@abstractmethod
def _delete_all_cypher(self, filters):
"""
Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
"""
pass
def get_all(self, filters, limit=100):
"""
Retrieves all nodes and relationships from the graph database based on filtering criteria.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
Returns:
list: A list of dictionaries, each containing:
- 'source': the source node name
- 'relationship': the relationship type
- 'target': the target node name
"""
# return all nodes and relationships
query, params = self._get_all_cypher(filters, limit)
results = self.graph.query(query, params=params)
final_results = []
for result in results:
final_results.append(
{
"source": result["source"],
"relationship": result["relationship"],
"target": result["target"],
}
)
logger.debug(f"Retrieved {len(final_results)} relationships")
return final_results
@abstractmethod
def _get_all_cypher(self, filters, limit):
"""
Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
"""
pass
def _search_graph_db(self, node_list, filters, limit=100):
"""
Search for similar nodes and their respective incoming and outgoing relations.
"""
result_relations = []
for node in node_list:
n_embedding = self.embedding_model.embed(node)
cypher_query, params = self._search_graph_db_cypher(n_embedding, filters, limit)
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
@abstractmethod
def _search_graph_db_cypher(self, n_embedding, filters, limit):
"""
Returns the OpenCypher query and parameters to search for similar nodes in the memory store
"""
pass
# Reset is implemented here for Neptune Analytics using the boto3 neptune-graph client
def reset(self):
"""
Reset the graph by clearing all nodes and relationships.
link: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/neptune-graph/client/reset_graph.html
"""
logger.warning("Clearing graph...")
graph_id = self.graph.graph_identifier
self.graph.client.reset_graph(
graphIdentifier=graph_id,
skipSnapshot=True,
)
waiter = self.graph.client.get_waiter("graph_available")
waiter.wait(graphIdentifier=graph_id, WaiterConfig={"Delay": 10, "MaxAttempts": 60})

372
mem0/graphs/neptune/main.py Normal file
View File

@@ -0,0 +1,372 @@
import logging
from .base import NeptuneBase
try:
from langchain_aws import NeptuneAnalyticsGraph
except ImportError:
raise ImportError("langchain_aws is not installed. Please install it using 'make install_all'.")
logger = logging.getLogger(__name__)
class MemoryGraph(NeptuneBase):
def __init__(self, config):
self.config = config
self.graph = None
endpoint = self.config.graph_store.config.endpoint
if endpoint and endpoint.startswith("neptune-graph://"):
graph_identifier = endpoint.replace("neptune-graph://", "")
self.graph = NeptuneAnalyticsGraph(graph_identifier)
if not self.graph:
raise ValueError("Unable to create a Neptune client: missing 'endpoint' in config")
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
self.embedding_model = NeptuneBase._create_embedding_model(self.config)
self.llm_provider = "openai_structured"
if self.config.llm.provider:
self.llm_provider = self.config.llm.provider
if self.config.graph_store.llm:
self.llm_provider = self.config.graph_store.llm.provider
self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
self.user_id = None
self.threshold = 0.7
def _delete_entities_cypher(self, source, destination, relationship, user_id):
"""
Returns the OpenCypher query and parameters for deleting entities in the graph DB
:param source: source node
:param destination: destination node
:param relationship: relationship label
:param user_id: user_id to use
:return: str, dict
"""
cypher = f"""
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
-[r:{relationship}]->
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
DELETE r
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
"""
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
logger.debug(f"_delete_entities\n query={cypher}")
return cypher, params
def _add_entities_cypher(
self,
source_node_list,
source,
source_embedding,
source_type,
destination_node_list,
destination,
dest_embedding,
destination_type,
relationship,
user_id,
):
"""
Returns the OpenCypher query and parameters for adding entities in the graph DB
:param source_node_list: list of source nodes
:param source: source node name
:param source_embedding: source node embedding
:param source_type: source node label
:param destination_node_list: list of dest nodes
:param destination: destination name
:param dest_embedding: destination embedding
:param destination_type: destination node label
:param relationship: relationship label
:param user_id: user id to use
:return: str, dict
"""
source_label = self.node_label if self.node_label else f":`{source_type}`"
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
# TODO: refactor this code to align with the graph_memory.py implementation
if not destination_node_list and source_node_list:
cypher = f"""
MATCH (source)
WHERE id(source) = $source_id
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
ON CREATE SET
destination.created = timestamp(),
destination.mentions = 1
{destination_extra_set}
ON MATCH SET
destination.mentions = coalesce(destination.mentions, 0) + 1
WITH source, destination, $dest_embedding as dest_embedding
CALL neptune.algo.vectors.upsert(destination, dest_embedding)
WITH source, destination
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp(),
r.mentions = 1
ON MATCH SET
r.mentions = coalesce(r.mentions, 0) + 1
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_list[0]["id(source_candidate)"],
"destination_name": destination,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
elif destination_node_list and not source_node_list:
cypher = f"""
MATCH (destination)
WHERE id(destination) = $destination_id
SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH destination
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
ON CREATE SET
source.created = timestamp(),
source.mentions = 1
{source_extra_set}
ON MATCH SET
source.mentions = coalesce(source.mentions, 0) + 1
WITH source, destination, $source_embedding as source_embedding
CALL neptune.algo.vectors.upsert(source, source_embedding)
WITH source, destination
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp(),
r.mentions = 1
ON MATCH SET
r.mentions = coalesce(r.mentions, 0) + 1
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"destination_id": destination_node_list[0]["id(destination_candidate)"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
}
elif source_node_list and destination_node_list:
cypher = f"""
MATCH (source)
WHERE id(source) = $source_id
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MATCH (destination)
WHERE id(destination) = $destination_id
SET destination.mentions = coalesce(destination.mentions, 0) + 1
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created_at = timestamp(),
r.updated_at = timestamp(),
r.mentions = 1
ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_list[0]["id(source_candidate)"],
"destination_id": destination_node_list[0]["id(destination_candidate)"],
"user_id": user_id,
}
else:
cypher = f"""
MERGE (n {source_label} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(),
n.mentions = 1
{source_extra_set}
ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
WITH n, $source_embedding as source_embedding
CALL neptune.algo.vectors.upsert(n, source_embedding)
WITH n
MERGE (m {destination_label} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(),
m.mentions = 1
{destination_extra_set}
ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
WITH n, m, $dest_embedding as dest_embedding
CALL neptune.algo.vectors.upsert(m, dest_embedding)
WITH n, m
MERGE (n)-[rel:{relationship}]->(m)
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
"""
params = {
"source_name": source,
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
logger.debug(
f"_add_entities:\n destination_node_search_result={destination_node_list}\n source_node_search_result={source_node_list}\n query={cypher}"
)
return cypher, params
def _search_source_node_cypher(self, source_embedding, user_id, threshold):
"""
Returns the OpenCypher query and parameters to search for source nodes
:param source_embedding: source vector
:param user_id: user_id to use
:param threshold: the threshold for similarity
:return: str, dict
"""
cypher = f"""
MATCH (source_candidate {self.node_label})
WHERE source_candidate.user_id = $user_id
WITH source_candidate, $source_embedding as v_embedding
CALL neptune.algo.vectors.distanceByEmbedding(
v_embedding,
source_candidate,
{{metric:"CosineSimilarity"}}
) YIELD distance
WITH source_candidate, distance AS cosine_similarity
WHERE cosine_similarity >= $threshold
WITH source_candidate, cosine_similarity
ORDER BY cosine_similarity DESC
LIMIT 1
RETURN id(source_candidate), cosine_similarity
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"threshold": threshold,
}
logger.debug(f"_search_source_node\n query={cypher}")
return cypher, params
def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
"""
Returns the OpenCypher query and parameters to search for destination nodes
:param destination_embedding: destination vector
:param user_id: user_id to use
:param threshold: the threshold for similarity
:return: str, dict
"""
cypher = f"""
MATCH (destination_candidate {self.node_label})
WHERE destination_candidate.user_id = $user_id
WITH destination_candidate, $destination_embedding as v_embedding
CALL neptune.algo.vectors.distanceByEmbedding(
v_embedding,
destination_candidate,
{{metric:"CosineSimilarity"}}
) YIELD distance
WITH destination_candidate, distance AS cosine_similarity
WHERE cosine_similarity >= $threshold
WITH destination_candidate, cosine_similarity
ORDER BY cosine_similarity DESC
LIMIT 1
RETURN id(destination_candidate), cosine_similarity
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"threshold": threshold,
}
logger.debug(f"_search_destination_node\n query={cypher}")
return cypher, params
def _delete_all_cypher(self, filters):
"""
Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
:param filters: search filters
:return: str, dict
"""
cypher = f"""
MATCH (n {self.node_label} {{user_id: $user_id}})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
logger.debug(f"delete_all query={cypher}")
return cypher, params
def _get_all_cypher(self, filters, limit):
"""
Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
:param filters: search filters
:param limit: return limit
:return: str, dict
"""
cypher = f"""
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
params = {"user_id": filters["user_id"], "limit": limit}
return cypher, params
def _search_graph_db_cypher(self, n_embedding, filters, limit):
"""
Returns the OpenCypher query and parameters to search for similar nodes in the memory store
:param n_embedding: node vector
:param filters: search filters
:param limit: return limit
:return: str, dict
"""
cypher_query = f"""
MATCH (n {self.node_label})
WHERE n.user_id = $user_id
WITH n, $n_embedding as n_embedding
CALL neptune.algo.vectors.distanceByEmbedding(
n_embedding,
n,
{{metric:"CosineSimilarity"}}
) YIELD distance
WITH n, distance as similarity
WHERE similarity >= $threshold
CALL {{
WITH n
MATCH (n)-[r]->(m)
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
UNION ALL
WITH n
MATCH (m)-[r]->(n)
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
}}
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
logger.debug(f"_search_graph_db\n query={cypher_query}")
return cypher_query, params

View File

@@ -621,3 +621,12 @@ class MemoryGraph:
result = self.graph.query(cypher, params=params)
return result
# Reset is not defined in base.py
def reset(self):
"""Reset the graph by clearing all nodes and relationships."""
logger.warning("Clearing graph...")
cypher_query = """
MATCH (n) DETACH DELETE n
"""
return self.graph.query(cypher_query)

View File

@@ -138,6 +138,8 @@ class Memory(MemoryBase):
if self.config.graph_store.config:
if self.config.graph_store.provider == "memgraph":
from mem0.memory.memgraph_memory import MemoryGraph
elif self.config.graph_store.provider == "neptune":
from mem0.graphs.neptune.main import MemoryGraph
else:
from mem0.memory.graph_memory import MemoryGraph

View File

@@ -23,6 +23,7 @@ dependencies = [
[project.optional-dependencies]
graph = [
"langchain-neo4j>=0.4.0",
"langchain-aws>=0.2.23",
"neo4j>=5.23.1",
"rank-bm25>=0.2.2",
]

View File

@@ -0,0 +1,335 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
from mem0.graphs.neptune.main import MemoryGraph
from mem0.graphs.neptune.base import NeptuneBase
class TestNeptuneMemory(unittest.TestCase):
"""Test suite for the Neptune Memory implementation."""
def setUp(self):
"""Set up test fixtures before each test method."""
# Create a mock config
self.config = MagicMock()
self.config.graph_store.config.endpoint = "neptune-graph://test-graph"
self.config.graph_store.config.base_label = True
self.config.llm.provider = "openai_structured"
self.config.graph_store.llm = None
self.config.graph_store.custom_prompt = None
# Create mock for NeptuneAnalyticsGraph
self.mock_graph = MagicMock()
self.mock_graph.client.get_graph.return_value = {"status": "AVAILABLE"}
# Create mocks for static methods
self.mock_embedding_model = MagicMock()
self.mock_llm = MagicMock()
# Patch the necessary components
self.neptune_analytics_graph_patcher = patch("mem0.graphs.neptune.main.NeptuneAnalyticsGraph")
self.mock_neptune_analytics_graph = self.neptune_analytics_graph_patcher.start()
self.mock_neptune_analytics_graph.return_value = self.mock_graph
# Patch the static methods
self.create_embedding_model_patcher = patch.object(NeptuneBase, "_create_embedding_model")
self.mock_create_embedding_model = self.create_embedding_model_patcher.start()
self.mock_create_embedding_model.return_value = self.mock_embedding_model
self.create_llm_patcher = patch.object(NeptuneBase, "_create_llm")
self.mock_create_llm = self.create_llm_patcher.start()
self.mock_create_llm.return_value = self.mock_llm
# Create the MemoryGraph instance
self.memory_graph = MemoryGraph(self.config)
# Set up common test data
self.user_id = "test_user"
self.test_filters = {"user_id": self.user_id}
def tearDown(self):
"""Tear down test fixtures after each test method."""
self.neptune_analytics_graph_patcher.stop()
self.create_embedding_model_patcher.stop()
self.create_llm_patcher.stop()
def test_initialization(self):
"""Test that the MemoryGraph is initialized correctly."""
self.assertEqual(self.memory_graph.graph, self.mock_graph)
self.assertEqual(self.memory_graph.embedding_model, self.mock_embedding_model)
self.assertEqual(self.memory_graph.llm, self.mock_llm)
self.assertEqual(self.memory_graph.llm_provider, "openai_structured")
self.assertEqual(self.memory_graph.node_label, ":`__Entity__`")
self.assertEqual(self.memory_graph.threshold, 0.7)
def test_init(self):
"""Test the class init functions"""
# Create a mock config with no endpoint
config_no_endpoint = MagicMock()
config_no_endpoint.graph_store.config.endpoint = None
# Create the MemoryGraph instance
with pytest.raises(ValueError):
MemoryGraph(config_no_endpoint)
# Create a mock config with an unsupported neptune-db endpoint
config_ndb_endpoint = MagicMock()
config_ndb_endpoint.graph_store.config.endpoint = "neptune-db://test-graph"
with pytest.raises(ValueError):
MemoryGraph(config_ndb_endpoint)
def test_add_method(self):
"""Test the add method with mocked components."""
# Mock the necessary methods that add() calls
self.memory_graph._retrieve_nodes_from_data = MagicMock(return_value={"alice": "person", "bob": "person"})
self.memory_graph._establish_nodes_relations_from_data = MagicMock(
return_value=[{"source": "alice", "relationship": "knows", "destination": "bob"}]
)
self.memory_graph._search_graph_db = MagicMock(return_value=[])
self.memory_graph._get_delete_entities_from_search_output = MagicMock(return_value=[])
self.memory_graph._delete_entities = MagicMock(return_value=[])
self.memory_graph._add_entities = MagicMock(
return_value=[{"source": "alice", "relationship": "knows", "target": "bob"}]
)
# Call the add method
result = self.memory_graph.add("Alice knows Bob", self.test_filters)
# Verify the method calls
self.memory_graph._retrieve_nodes_from_data.assert_called_once_with("Alice knows Bob", self.test_filters)
self.memory_graph._establish_nodes_relations_from_data.assert_called_once()
self.memory_graph._search_graph_db.assert_called_once()
self.memory_graph._get_delete_entities_from_search_output.assert_called_once()
self.memory_graph._delete_entities.assert_called_once_with([], self.user_id)
self.memory_graph._add_entities.assert_called_once()
# Check the result structure
self.assertIn("deleted_entities", result)
self.assertIn("added_entities", result)
def test_search_method(self):
"""Test the search method with mocked components."""
# Mock the necessary methods that search() calls
self.memory_graph._retrieve_nodes_from_data = MagicMock(return_value={"alice": "person"})
# Mock search results
mock_search_results = [
{"source": "alice", "relationship": "knows", "destination": "bob"},
{"source": "alice", "relationship": "works_with", "destination": "charlie"},
]
self.memory_graph._search_graph_db = MagicMock(return_value=mock_search_results)
# Mock BM25Okapi
with patch("mem0.graphs.neptune.base.BM25Okapi") as mock_bm25:
mock_bm25_instance = MagicMock()
mock_bm25.return_value = mock_bm25_instance
# Mock get_top_n to return reranked results
reranked_results = [["alice", "knows", "bob"], ["alice", "works_with", "charlie"]]
mock_bm25_instance.get_top_n.return_value = reranked_results
# Call the search method
result = self.memory_graph.search("Find Alice", self.test_filters, limit=5)
# Verify the method calls
self.memory_graph._retrieve_nodes_from_data.assert_called_once_with("Find Alice", self.test_filters)
self.memory_graph._search_graph_db.assert_called_once_with(node_list=["alice"], filters=self.test_filters)
# Check the result structure
self.assertEqual(len(result), 2)
self.assertEqual(result[0]["source"], "alice")
self.assertEqual(result[0]["relationship"], "knows")
self.assertEqual(result[0]["destination"], "bob")
def test_get_all_method(self):
"""Test the get_all method."""
# Mock the _get_all_cypher method
mock_cypher = "MATCH (n) RETURN n"
mock_params = {"user_id": self.user_id, "limit": 10}
self.memory_graph._get_all_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Mock the graph.query result
mock_query_result = [
{"source": "alice", "relationship": "knows", "target": "bob"},
{"source": "bob", "relationship": "works_with", "target": "charlie"},
]
self.mock_graph.query.return_value = mock_query_result
# Call the get_all method
result = self.memory_graph.get_all(self.test_filters, limit=10)
# Verify the method calls
self.memory_graph._get_all_cypher.assert_called_once_with(self.test_filters, 10)
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
# Check the result structure
self.assertEqual(len(result), 2)
self.assertEqual(result[0]["source"], "alice")
self.assertEqual(result[0]["relationship"], "knows")
self.assertEqual(result[0]["target"], "bob")
def test_delete_all_method(self):
"""Test the delete_all method."""
# Mock the _delete_all_cypher method
mock_cypher = "MATCH (n) DETACH DELETE n"
mock_params = {"user_id": self.user_id}
self.memory_graph._delete_all_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Call the delete_all method
self.memory_graph.delete_all(self.test_filters)
# Verify the method calls
self.memory_graph._delete_all_cypher.assert_called_once_with(self.test_filters)
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
def test_search_source_node(self):
"""Test the _search_source_node method."""
# Mock embedding
mock_embedding = [0.1, 0.2, 0.3]
# Mock the _search_source_node_cypher method
mock_cypher = "MATCH (n) RETURN n"
mock_params = {"source_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.9}
self.memory_graph._search_source_node_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Mock the graph.query result
mock_query_result = [{"id(source_candidate)": 123, "cosine_similarity": 0.95}]
self.mock_graph.query.return_value = mock_query_result
# Call the _search_source_node method
result = self.memory_graph._search_source_node(mock_embedding, self.user_id, threshold=0.9)
# Verify the method calls
self.memory_graph._search_source_node_cypher.assert_called_once_with(mock_embedding, self.user_id, 0.9)
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
# Check the result
self.assertEqual(result, mock_query_result)
def test_search_destination_node(self):
"""Test the _search_destination_node method."""
# Mock embedding
mock_embedding = [0.1, 0.2, 0.3]
# Mock the _search_destination_node_cypher method
mock_cypher = "MATCH (n) RETURN n"
mock_params = {"destination_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.9}
self.memory_graph._search_destination_node_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Mock the graph.query result
mock_query_result = [{"id(destination_candidate)": 456, "cosine_similarity": 0.92}]
self.mock_graph.query.return_value = mock_query_result
# Call the _search_destination_node method
result = self.memory_graph._search_destination_node(mock_embedding, self.user_id, threshold=0.9)
# Verify the method calls
self.memory_graph._search_destination_node_cypher.assert_called_once_with(mock_embedding, self.user_id, 0.9)
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
# Check the result
self.assertEqual(result, mock_query_result)
def test_search_graph_db(self):
"""Test the _search_graph_db method."""
# Mock node list
node_list = ["alice", "bob"]
# Mock embedding
mock_embedding = [0.1, 0.2, 0.3]
self.mock_embedding_model.embed.return_value = mock_embedding
# Mock the _search_graph_db_cypher method
mock_cypher = "MATCH (n) RETURN n"
mock_params = {"n_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.7, "limit": 10}
self.memory_graph._search_graph_db_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Mock the graph.query results
mock_query_result1 = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
mock_query_result2 = [{"source": "bob", "relationship": "works_with", "destination": "charlie"}]
self.mock_graph.query.side_effect = [mock_query_result1, mock_query_result2]
# Call the _search_graph_db method
result = self.memory_graph._search_graph_db(node_list, self.test_filters, limit=10)
# Verify the method calls
self.assertEqual(self.mock_embedding_model.embed.call_count, 2)
self.assertEqual(self.memory_graph._search_graph_db_cypher.call_count, 2)
self.assertEqual(self.mock_graph.query.call_count, 2)
# Check the result
expected_result = mock_query_result1 + mock_query_result2
self.assertEqual(result, expected_result)
def test_add_entities(self):
"""Test the _add_entities method."""
# Mock data
to_be_added = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
entity_type_map = {"alice": "person", "bob": "person"}
# Mock embeddings
mock_embedding = [0.1, 0.2, 0.3]
self.mock_embedding_model.embed.return_value = mock_embedding
# Mock search results
mock_source_search = [{"id(source_candidate)": 123, "cosine_similarity": 0.95}]
mock_dest_search = [{"id(destination_candidate)": 456, "cosine_similarity": 0.92}]
# Mock the search methods
self.memory_graph._search_source_node = MagicMock(return_value=mock_source_search)
self.memory_graph._search_destination_node = MagicMock(return_value=mock_dest_search)
# Mock the _add_entities_cypher method
mock_cypher = "MATCH (n) RETURN n"
mock_params = {"source_id": 123, "destination_id": 456}
self.memory_graph._add_entities_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Mock the graph.query result
mock_query_result = [{"source": "alice", "relationship": "knows", "target": "bob"}]
self.mock_graph.query.return_value = mock_query_result
# Call the _add_entities method
result = self.memory_graph._add_entities(to_be_added, self.user_id, entity_type_map)
# Verify the method calls
self.assertEqual(self.mock_embedding_model.embed.call_count, 2)
self.memory_graph._search_source_node.assert_called_once_with(mock_embedding, self.user_id, threshold=0.9)
self.memory_graph._search_destination_node.assert_called_once_with(mock_embedding, self.user_id, threshold=0.9)
self.memory_graph._add_entities_cypher.assert_called_once()
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
# Check the result
self.assertEqual(result, [mock_query_result])
def test_delete_entities(self):
"""Test the _delete_entities method."""
# Mock data
to_be_deleted = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
# Mock the _delete_entities_cypher method
mock_cypher = "MATCH (n) RETURN n"
mock_params = {"source_name": "alice", "dest_name": "bob", "user_id": self.user_id}
self.memory_graph._delete_entities_cypher = MagicMock(return_value=(mock_cypher, mock_params))
# Mock the graph.query result
mock_query_result = [{"source": "alice", "relationship": "knows", "target": "bob"}]
self.mock_graph.query.return_value = mock_query_result
# Call the _delete_entities method
result = self.memory_graph._delete_entities(to_be_deleted, self.user_id)
# Verify the method calls
self.memory_graph._delete_entities_cypher.assert_called_once_with("alice", "bob", "knows", self.user_id)
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
# Check the result
self.assertEqual(result, [mock_query_result])
if __name__ == "__main__":
unittest.main()