Add Amazon Neptune Analytics graph_store configuration & integration (#2949)
Parent: 7484eed4b2 · Commit: 05c404d8d3

Makefile (4 lines changed)
@@ -12,8 +12,8 @@ install:

 install_all:
 	pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
-		google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
-		upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo
+		google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
+		upstash-vector azure-search-documents langchain-memgraph langchain-neo4j langchain-aws rank-bm25 pymochow pymongo

 # Format code with ruff
 format:
@@ -232,6 +232,66 @@ m = Memory.from_config(config_dict=config)
```
</CodeGroup>

### Initialize Neptune Analytics

Mem0 now supports Amazon Neptune Analytics as a graph store provider. This integration lets you use Neptune Analytics to store and query graph-based memories.

#### Instance Setup

Create an Amazon Neptune Analytics instance in your AWS account following the [AWS documentation](https://docs.aws.amazon.com/neptune-analytics/latest/userguide/get-started.html).

- Public connectivity is not enabled by default; if you are accessing the instance from outside a VPC, it must be enabled.
- Once the Amazon Neptune Analytics instance is available, you will need its graph identifier to connect.
- The Neptune Analytics instance must be created with the same vector dimensions that your embedding model produces. See: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/vector-index.html
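Because a dimension mismatch only surfaces once queries run, it can be worth checking the instance's vector index up front. A rough sketch using the boto3 `neptune-graph` client (the helper name and stub-client pattern are our own; verify the response shape against the current Neptune Analytics API reference):

```python
def check_vector_dimension(graph_id, expected_dims, client=None):
    """Return True if the graph's vector index dimension matches expected_dims."""
    if client is None:
        import boto3  # requires AWS credentials with neptune-graph read access
        client = boto3.client("neptune-graph")
    info = client.get_graph(graphIdentifier=graph_id)
    # GetGraph reports the vector index configured at graph creation time.
    actual = info.get("vectorSearchConfiguration", {}).get("dimension")
    return actual == expected_dims
```

For example, `check_vector_dimension("g-abc123", 1536)` should hold before pairing the graph with a 1536-dimension embedder.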
#### Attach Credentials

Configure AWS credentials with access to your Amazon Neptune Analytics resources by following the [Configuration and credentials precedence](https://docs.aws.amazon.com/cli/v1/userguide/cli-chap-configure.html#configure-precedence).

- For example, provide your AWS access key and session token via environment variables:

```bash
export AWS_ACCESS_KEY_ID=your-access-key
export AWS_SECRET_ACCESS_KEY=your-secret-key
export AWS_SESSION_TOKEN=your-session-token
export AWS_DEFAULT_REGION=your-region
```

- The IAM user or role making the request must have a policy attached that allows one of the following IAM actions on that neptune-graph:
  - neptune-graph:ReadDataViaQuery
  - neptune-graph:WriteDataViaQuery
  - neptune-graph:DeleteDataViaQuery
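Since a missing variable otherwise only fails deep inside the client, it can help to fail fast before connecting; a minimal sketch (the helper name is ours, and `AWS_SESSION_TOKEN` is left optional because it only applies to temporary credentials):

```python
import os

REQUIRED_AWS_VARS = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"]

def missing_aws_vars(env=None):
    """Return the names of required AWS variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_AWS_VARS if not env.get(name)]

if __name__ == "__main__":
    missing = missing_aws_vars()
    if missing:
        raise SystemExit(f"Set these AWS variables before connecting: {missing}")
```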
#### Usage

The Neptune memory store uses the AWS LangChain Python API to connect to Neptune instances. For additional configuration options for connecting to your Amazon Neptune Analytics instance, see the [AWS LangChain API documentation](https://python.langchain.com/api_reference/aws/graphs/langchain_aws.graphs.neptune_graph.NeptuneAnalyticsGraph.html).

<CodeGroup>
```python Python
from mem0 import Memory

# This example must connect to a neptune-graph instance created with 1536 vector dimensions.
config = {
    "embedder": {
        "provider": "openai",
        "config": {"model": "text-embedding-3-large", "embedding_dims": 1536},
    },
    "graph_store": {
        "provider": "neptune",
        "config": {
            "endpoint": "neptune-graph://<GRAPH_ID>",
        },
    },
}

m = Memory.from_config(config_dict=config)
```
</CodeGroup>
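Because the endpoint string and the embedder's dimensions must both match the instance, it can help to build the config dict in one place. A small sketch (the `GRAPH_ID` environment variable and helper name are our own convention, not part of Mem0):

```python
import os

def neptune_graph_config(graph_id, embedding_dims=1536):
    """Build a Mem0 config dict for a Neptune Analytics graph store.

    The graph must have been created with a vector index matching
    embedding_dims (1536 here, to pair with text-embedding-3-large).
    """
    return {
        "embedder": {
            "provider": "openai",
            "config": {"model": "text-embedding-3-large", "embedding_dims": embedding_dims},
        },
        "graph_store": {
            "provider": "neptune",
            "config": {"endpoint": f"neptune-graph://{graph_id}"},
        },
    }

# "g-example" is a placeholder fallback for illustration only.
config = neptune_graph_config(os.environ.get("GRAPH_ID", "g-example"))
```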
#### Troubleshooting

- For issues connecting to Amazon Neptune Analytics, refer to the [Connecting to a graph guide](https://docs.aws.amazon.com/neptune-analytics/latest/userguide/gettingStarted-connecting.html).
- For issues related to authentication, refer to the [boto3 client configuration options](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html).
- For more details on how to connect, configure, and use the Neptune graph store, see the [Neptune Analytics example notebook](examples/graph-db-demo/neptune-analytics-example.ipynb).

## Graph Operations

Mem0's graph supports the following operations:
593 examples/graph-db-demo/neptune-example.ipynb Normal file
@@ -0,0 +1,593 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neptune as Graph Memory\n",
    "\n",
    "In this notebook, we will be connecting to an Amazon Neptune Analytics instance as our graph memory storage for Mem0.\n",
    "\n",
    "The Graph Memory storage persists memories in graph (relationship) form when performing `m.add` memory operations. It then uses vector distance algorithms to find related memories during a `m.search` operation. Relationships are returned in the result and add context to the memories.\n",
    "\n",
    "Reference: [Vector Similarity using Neptune Analytics](https://docs.aws.amazon.com/neptune-analytics/latest/userguide/vector-similarity.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "### 1. Install Mem0 with Graph Memory support\n",
    "\n",
    "To use Mem0 with Graph Memory support, install it using pip:\n",
    "\n",
    "```bash\n",
    "pip install \"mem0ai[graph]\"\n",
    "```\n",
    "\n",
    "This command installs Mem0 along with the necessary dependencies for graph functionality.\n",
    "\n",
    "### 2. Connect to Neptune\n",
    "\n",
    "To connect to Amazon Neptune Analytics, you need to configure Neptune with your Amazon profile credentials. The best way to do this is to declare environment variables with IAM permission to your Neptune Analytics instance. The `graph-identifier` of the instance that will persist memories needs to be defined in the Mem0 configuration under `\"graph_store\"`, with the `\"neptune\"` provider. Note that the Neptune Analytics instance needs a `vector-search-configuration` that matches the embedding model's vector dimensions, see: https://docs.aws.amazon.com/neptune-analytics/latest/userguide/vector-index.html.\n",
    "\n",
    "```python\n",
    "embedding_dimensions = 1536\n",
    "graph_identifier = \"<MY-GRAPH>\"  # graph with 1536 dimensions for vector search\n",
    "config = {\n",
    "    \"embedder\": {\n",
    "        \"provider\": \"openai\",\n",
    "        \"config\": {\n",
    "            \"model\": \"text-embedding-3-large\",\n",
    "            \"embedding_dims\": embedding_dimensions\n",
    "        },\n",
    "    },\n",
    "    \"graph_store\": {\n",
    "        \"provider\": \"neptune\",\n",
    "        \"config\": {\n",
    "            \"endpoint\": f\"neptune-graph://{graph_identifier}\",\n",
    "        },\n",
    "    },\n",
    "}\n",
    "```\n",
    "\n",
    "### 3. Configure OpenSearch\n",
    "\n",
    "We're going to use OpenSearch as our vector store. You can run [OpenSearch from a docker image](https://docs.opensearch.org/docs/latest/install-and-configure/install-opensearch/docker/):\n",
    "\n",
    "```bash\n",
    "docker pull opensearchproject/opensearch:2\n",
    "```\n",
    "\n",
    "Then run it with a `<custom-admin-password>` and verify that it's up:\n",
    "\n",
    "```bash\n",
    "docker run -d -p 9200:9200 -p 9600:9600 -e \"discovery.type=single-node\" -e \"OPENSEARCH_INITIAL_ADMIN_PASSWORD=<custom-admin-password>\" opensearchproject/opensearch:latest\n",
    "\n",
    "curl https://localhost:9200 -ku admin:<custom-admin-password>\n",
    "```\n",
    "\n",
    "We're going to connect to [OpenSearch using the python client](https://github.com/opensearch-project/opensearch-py):\n",
    "\n",
    "```bash\n",
    "pip install \"opensearch-py\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Do all the imports and configure OpenAI (enter your OpenAI API key):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-03T20:52:48.330121Z",
     "start_time": "2025-07-03T20:52:47.092369Z"
    }
   },
   "source": [
    "from mem0 import Memory\n",
    "import os\n",
    "import logging\n",
    "import sys\n",
    "\n",
    "logging.getLogger(\"mem0.graphs.neptune.main\").setLevel(logging.DEBUG)\n",
    "logging.getLogger(\"mem0.graphs.neptune.base\").setLevel(logging.DEBUG)\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "logging.basicConfig(\n",
    "    format=\"%(levelname)s - %(message)s\",\n",
    "    datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
    "    stream=sys.stdout,  # Explicitly set output to stdout\n",
    ")"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the Mem0 configuration using:\n",
    "- openai as the embedder\n",
    "- an Amazon Neptune Analytics instance as the graph store\n",
    "- OpenSearch as the vector store"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-03T20:52:50.958741Z",
     "start_time": "2025-07-03T20:52:50.955127Z"
    }
   },
   "source": [
    "graph_identifier = os.environ.get(\"GRAPH_ID\")\n",
    "opensearch_username = os.environ.get(\"OS_USERNAME\")\n",
    "opensearch_password = os.environ.get(\"OS_PASSWORD\")\n",
    "config = {\n",
    "    \"embedder\": {\n",
    "        \"provider\": \"openai\",\n",
    "        \"config\": {\"model\": \"text-embedding-3-large\", \"embedding_dims\": 1536},\n",
    "    },\n",
    "    \"graph_store\": {\n",
    "        \"provider\": \"neptune\",\n",
    "        \"config\": {\n",
    "            \"endpoint\": f\"neptune-graph://{graph_identifier}\",\n",
    "        },\n",
    "    },\n",
    "    \"vector_store\": {\n",
    "        \"provider\": \"opensearch\",\n",
    "        \"config\": {\n",
    "            \"collection_name\": \"vector_store\",\n",
    "            \"host\": \"localhost\",\n",
    "            \"port\": 9200,\n",
    "            \"user\": opensearch_username,\n",
    "            \"password\": opensearch_password,\n",
    "            \"use_ssl\": False,\n",
    "            \"verify_certs\": False,\n",
    "        },\n",
    "    },\n",
    "}"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Memory initialization\n",
    "\n",
    "Initialize Neptune Analytics as the Graph Memory store:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-03T20:52:55.655673Z",
     "start_time": "2025-07-03T20:52:54.141041Z"
    }
   },
   "source": [
    "m = Memory.from_config(config_dict=config)\n",
    "\n",
    "app_id = \"movies\"\n",
    "user_id = \"alice\"\n",
    "\n",
    "m.delete_all(user_id=user_id)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING - Creating index vector_store, it might take 1-2 minutes...\n",
      "WARNING - Creating index mem0migrations, it might take 1-2 minutes...\n",
      "DEBUG - delete_all query=\n",
      "    MATCH (n {user_id: $user_id})\n",
      "    DETACH DELETE n\n",
      "    \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'message': 'Memories deleted successfully!'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Store memories\n",
    "\n",
    "Create memories and store them one at a time:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-03T20:53:05.338249Z",
     "start_time": "2025-07-03T20:52:57.528210Z"
    }
   },
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"I'm planning to watch a movie tonight. Any recommendations?\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# Store inferred memories (default behavior)\n",
    "result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
    "\n",
    "all_results = m.get_all(user_id=user_id)\n",
    "for n in all_results[\"results\"]:\n",
    "    print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
    "\n",
    "for e in all_results[\"relations\"]:\n",
    "    print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DEBUG - Extracted entities: [{'source': 'alice', 'relationship': 'plans_to_watch', 'destination': 'movie'}]\n",
      "DEBUG - _search_graph_db\n",
      "  query=\n",
      "    MATCH (n )\n",
      "    WHERE n.user_id = $user_id\n",
      "    WITH n, $n_embedding as n_embedding\n",
      "    CALL neptune.algo.vectors.distanceByEmbedding(\n",
      "        n_embedding,\n",
      "        n,\n",
      "        {metric:\"CosineSimilarity\"}\n",
      "    ) YIELD distance\n",
      "    WITH n, distance as similarity\n",
      "    WHERE similarity >= $threshold\n",
      "    CALL {\n",
      "        WITH n\n",
      "        MATCH (n)-[r]->(m) \n",
      "        RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id\n",
      "        UNION ALL\n",
      "        WITH n\n",
      "        MATCH (m)-[r]->(n) \n",
      "        RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id\n",
      "    }\n",
      "    WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
      "    RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
      "    ORDER BY similarity DESC\n",
      "    LIMIT $limit\n",
      "    \n",
      "DEBUG - Deleted relationships: []\n",
      "DEBUG - _search_source_node\n",
      "  query=\n",
      "    MATCH (source_candidate )\n",
      "    WHERE source_candidate.user_id = $user_id \n",
      "\n",
      "    WITH source_candidate, $source_embedding as v_embedding\n",
      "    CALL neptune.algo.vectors.distanceByEmbedding(\n",
      "        v_embedding,\n",
      "        source_candidate,\n",
      "        {metric:\"CosineSimilarity\"}\n",
      "    ) YIELD distance\n",
      "    WITH source_candidate, distance AS cosine_similarity\n",
      "    WHERE cosine_similarity >= $threshold\n",
      "\n",
      "    WITH source_candidate, cosine_similarity\n",
      "    ORDER BY cosine_similarity DESC\n",
      "    LIMIT 1\n",
      "\n",
      "    RETURN id(source_candidate), cosine_similarity\n",
      "    \n",
      "DEBUG - _search_destination_node\n",
      "  query=\n",
      "    MATCH (destination_candidate )\n",
      "    WHERE destination_candidate.user_id = $user_id\n",
      "    \n",
      "    WITH destination_candidate, $destination_embedding as v_embedding\n",
      "    CALL neptune.algo.vectors.distanceByEmbedding(\n",
      "        v_embedding,\n",
      "        destination_candidate, \n",
      "        {metric:\"CosineSimilarity\"}\n",
      "    ) YIELD distance\n",
      "    WITH destination_candidate, distance AS cosine_similarity\n",
      "    WHERE cosine_similarity >= $threshold\n",
      "\n",
      "    WITH destination_candidate, cosine_similarity\n",
      "    ORDER BY cosine_similarity DESC\n",
      "    LIMIT 1\n",
      "    \n",
      "    RETURN id(destination_candidate), cosine_similarity\n",
      "    \n",
      "DEBUG - _add_entities:\n",
      "  destination_node_search_result=[]\n",
      "  source_node_search_result=[]\n",
      "  query=\n",
      "    MERGE (n :`__User__` {name: $source_name, user_id: $user_id})\n",
      "    ON CREATE SET n.created = timestamp(),\n",
      "                  n.mentions = 1\n",
      "    \n",
      "    ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1\n",
      "    WITH n, $source_embedding as source_embedding\n",
      "    CALL neptune.algo.vectors.upsert(n, source_embedding)\n",
      "    WITH n\n",
      "    MERGE (m :`entertainment` {name: $dest_name, user_id: $user_id})\n",
      "    ON CREATE SET m.created = timestamp(),\n",
      "                  m.mentions = 1\n",
      "    \n",
      "    ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1\n",
      "    WITH n, m, $dest_embedding as dest_embedding\n",
      "    CALL neptune.algo.vectors.upsert(m, dest_embedding)\n",
      "    WITH n, m\n",
      "    MERGE (n)-[rel:plans_to_watch]->(m)\n",
      "    ON CREATE SET rel.created = timestamp(), rel.mentions = 1\n",
      "    ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1\n",
      "    RETURN n.name AS source, type(rel) AS relationship, m.name AS target\n",
      "    \n",
      "DEBUG - Retrieved 1 relationships\n",
      "node \"Planning to watch a movie tonight\": [hash: bf55418607cfdca4afa311b5fd8496bd]\n",
      "edge \"alice\" --plans_to_watch--> \"movie\"\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-03T20:53:17.755933Z",
     "start_time": "2025-07-03T20:53:11.568772Z"
    }
   },
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"How about thriller movies? They can be quite engaging.\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# Store inferred memories (default behavior)\n",
    "result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
    "\n",
    "all_results = m.get_all(user_id=user_id)\n",
    "for n in all_results[\"results\"]:\n",
    "    print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
    "\n",
    "for e in all_results[\"relations\"]:\n",
    "    print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DEBUG - Extracted entities: [{'source': 'thriller_movies', 'relationship': 'is_engaging', 'destination': 'thriller_movies'}]\n",
      "DEBUG - _search_graph_db\n",
      "  query=\n",
      "    MATCH (n )\n",
      "    WHERE n.user_id = $user_id\n",
      "    WITH n, $n_embedding as n_embedding\n",
      "    CALL neptune.algo.vectors.distanceByEmbedding(\n",
      "        n_embedding,\n",
      "        n,\n",
      "        {metric:\"CosineSimilarity\"}\n",
      "    ) YIELD distance\n",
      "    WITH n, distance as similarity\n",
      "    WHERE similarity >= $threshold\n",
      "    CALL {\n",
      "        WITH n\n",
      "        MATCH (n)-[r]->(m) \n",
      "        RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id\n",
      "        UNION ALL\n",
      "        WITH n\n",
      "        MATCH (m)-[r]->(n) \n",
      "        RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id\n",
      "    }\n",
      "    WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
      "    RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity\n",
      "    ORDER BY similarity DESC\n",
      "    LIMIT $limit\n",
      "    \n",
      "DEBUG - Deleted relationships: []\n",
      "DEBUG - _search_source_node\n",
      "  query=\n",
      "    MATCH (source_candidate )\n",
      "    WHERE source_candidate.user_id = $user_id \n",
      "\n",
      "    WITH source_candidate, $source_embedding as v_embedding\n",
      "    CALL neptune.algo.vectors.distanceByEmbedding(\n",
      "        v_embedding,\n",
      "        source_candidate,\n",
      "        {metric:\"CosineSimilarity\"}\n",
      "    ) YIELD distance\n",
      "    WITH source_candidate, distance AS cosine_similarity\n",
      "    WHERE cosine_similarity >= $threshold\n",
      "\n",
      "    WITH source_candidate, cosine_similarity\n",
      "    ORDER BY cosine_similarity DESC\n",
      "    LIMIT 1\n",
      "\n",
      "    RETURN id(source_candidate), cosine_similarity\n",
      "    \n",
      "DEBUG - _search_destination_node\n",
      "  query=\n",
      "    MATCH (destination_candidate )\n",
      "    WHERE destination_candidate.user_id = $user_id\n",
      "    \n",
      "    WITH destination_candidate, $destination_embedding as v_embedding\n",
      "    CALL neptune.algo.vectors.distanceByEmbedding(\n",
      "        v_embedding,\n",
      "        destination_candidate, \n",
      "        {metric:\"CosineSimilarity\"}\n",
      "    ) YIELD distance\n",
      "    WITH destination_candidate, distance AS cosine_similarity\n",
      "    WHERE cosine_similarity >= $threshold\n",
      "\n",
      "    WITH destination_candidate, cosine_similarity\n",
      "    ORDER BY cosine_similarity DESC\n",
      "    LIMIT 1\n",
      "    \n",
      "    RETURN id(destination_candidate), cosine_similarity\n",
      "    \n",
      "DEBUG - _add_entities:\n",
      "  destination_node_search_result=[{'id(destination_candidate)': '67c49d52-e305-47fe-9fce-2cd5adc5d83c0', 'cosine_similarity': 0.999999}]\n",
      "  source_node_search_result=[{'id(source_candidate)': '67c49d52-e305-47fe-9fce-2cd5adc5d83c0', 'cosine_similarity': 0.999999}]\n",
      "  query=\n",
      "    MATCH (source)\n",
      "    WHERE id(source) = $source_id\n",
      "    SET source.mentions = coalesce(source.mentions, 0) + 1\n",
      "    WITH source\n",
      "    MATCH (destination)\n",
      "    WHERE id(destination) = $destination_id\n",
      "    SET destination.mentions = coalesce(destination.mentions) + 1\n",
      "    MERGE (source)-[r:is_engaging]->(destination)\n",
      "    ON CREATE SET \n",
      "        r.created_at = timestamp(),\n",
      "        r.updated_at = timestamp(),\n",
      "        r.mentions = 1\n",
      "    ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1\n",
      "    RETURN source.name AS source, type(r) AS relationship, destination.name AS target\n",
      "    \n",
      "DEBUG - Retrieved 3 relationships\n",
      "node \"Planning to watch a movie tonight\": [hash: bf55418607cfdca4afa311b5fd8496bd]\n",
      "edge \"thriller_movies\" --is_a_type_of--> \"movie\"\n",
      "edge \"alice\" --plans_to_watch--> \"movie\"\n",
      "edge \"thriller_movies\" --is_engaging--> \"thriller_movies\"\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2025-07-03T20:53:17.775656Z"
    }
   },
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"I'm not a big fan of thriller movies but I love sci-fi movies.\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# Store inferred memories (default behavior)\n",
    "result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
    "\n",
    "all_results = m.get_all(user_id=user_id)\n",
    "for n in all_results[\"results\"]:\n",
    "    print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
    "\n",
    "for e in all_results[\"relations\"]:\n",
    "    print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# Store inferred memories (default behavior)\n",
    "result = m.add(messages, user_id=user_id, metadata={\"category\": \"movie_recommendations\"})\n",
    "\n",
    "all_results = m.get_all(user_id=user_id)\n",
    "for n in all_results[\"results\"]:\n",
    "    print(f\"node \\\"{n['memory']}\\\": [hash: {n['hash']}]\")\n",
    "\n",
    "for e in all_results[\"relations\"]:\n",
    "    print(f\"edge \\\"{e['source']}\\\" --{e['relationship']}--> \\\"{e['target']}\\\"\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Search memories"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "search_results = m.search(\"what does alice love?\", user_id=user_id)\n",
    "for result in search_results[\"results\"]:\n",
    "    print(f\"\\\"{result['memory']}\\\" [score: {result['score']}]\")\n",
    "for relation in search_results[\"relations\"]:\n",
    "    print(f\"{relation}\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "m.delete_all(user_id=user_id)\n",
    "m.reset()"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
0 mem0/graphs/__init__.py Normal file
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

 from pydantic import BaseModel, Field, field_validator, model_validator

@@ -41,9 +41,43 @@ class MemgraphConfig(BaseModel):
         return values


+class NeptuneConfig(BaseModel):
+    endpoint: Optional[str] = Field(
+        None,
+        description="Endpoint to connect to a Neptune Analytics Server as neptune-graph://<graphid>",
+    )
+    base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
+
+    @model_validator(mode="before")
+    def check_host_port_or_path(cls, values):
+        endpoint = values.get("endpoint")
+        if not endpoint:
+            raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
+        if endpoint.startswith("neptune-db://"):
+            raise ValueError("neptune-db server is not yet supported")
+        elif endpoint.startswith("neptune-graph://"):
+            # This is a Neptune Analytics graph
+            graph_identifier = endpoint.replace("neptune-graph://", "")
+            if not graph_identifier.startswith("g-"):
+                raise ValueError("Provide a valid 'graph_identifier'.")
+            values["graph_identifier"] = graph_identifier
+            return values
+        else:
+            raise ValueError(
+                "You must provide an endpoint to create a NeptuneServer as either neptune-db://<endpoint> or neptune-graph://<graphid>"
+            )
+
+
 class GraphStoreConfig(BaseModel):
-    provider: str = Field(description="Provider of the data store (e.g., 'neo4j')", default="neo4j")
-    config: Neo4jConfig = Field(description="Configuration for the specific data store", default=None)
+    provider: str = Field(
+        description="Provider of the data store (e.g., 'neo4j', 'memgraph', 'neptune')",
+        default="neo4j",
+    )
+    config: Union[Neo4jConfig, MemgraphConfig, NeptuneConfig] = Field(
+        description="Configuration for the specific data store", default=None
+    )
     llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
     custom_prompt: Optional[str] = Field(
         description="Custom prompt to fetch entities from the given text", default=None
@@ -56,5 +90,7 @@ class GraphStoreConfig(BaseModel):
             return Neo4jConfig(**v.model_dump())
         elif provider == "memgraph":
             return MemgraphConfig(**v.model_dump())
+        elif provider == "neptune":
+            return NeptuneConfig(**v.model_dump())
         else:
             raise ValueError(f"Unsupported graph store provider: {provider}")
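For reference, the endpoint check that the new `NeptuneConfig` validator performs can be exercised as a standalone function; a rough restatement of the same rules (not part of the Mem0 API):

```python
def parse_neptune_endpoint(endpoint):
    """Validate a Neptune endpoint string and return the graph identifier.

    Mirrors the NeptuneConfig check: only neptune-graph://<graphid> is
    accepted, and the identifier must start with "g-".
    """
    if not endpoint:
        raise ValueError("Please provide 'endpoint' with the format as 'neptune-graph://<graphid>'.")
    if endpoint.startswith("neptune-db://"):
        raise ValueError("neptune-db server is not yet supported")
    if endpoint.startswith("neptune-graph://"):
        graph_identifier = endpoint.replace("neptune-graph://", "")
        if not graph_identifier.startswith("g-"):
            raise ValueError("Provide a valid 'graph_identifier'.")
        return graph_identifier
    raise ValueError("Endpoint must be neptune-db://<endpoint> or neptune-graph://<graphid>")
```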
0 mem0/graphs/neptune/__init__.py Normal file

410 mem0/graphs/neptune/base.py Normal file
@@ -0,0 +1,410 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from mem0.memory.utils import format_entities
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
|
||||
|
||||
from mem0.graphs.tools import (
|
||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
DELETE_MEMORY_TOOL_GRAPH,
|
||||
EXTRACT_ENTITIES_STRUCT_TOOL,
|
||||
EXTRACT_ENTITIES_TOOL,
|
||||
RELATIONS_STRUCT_TOOL,
|
||||
RELATIONS_TOOL,
|
||||
)
|
||||
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeptuneBase(ABC):
|
||||
"""
|
||||
Abstract base class for neptune (neptune analytics and neptune db) calls using OpenCypher
|
||||
to store/retrieve data
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_model(config):
|
||||
"""
|
||||
:return: the Embedder model used for memory store
|
||||
"""
|
||||
return EmbedderFactory.create(
|
||||
config.embedder.provider,
|
||||
config.embedder.config,
|
||||
{"enable_embeddings": True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_llm(config, llm_provider):
|
||||
"""
|
||||
:return: the llm model used for memory store
|
||||
"""
|
||||
return LlmFactory.create(llm_provider, config.llm.config)
|
||||
|
||||
    def add(self, data, filters):
        """
        Adds data to the graph.

        Args:
            data (str): The data to add to the graph.
            filters (dict): A dictionary containing filters to be applied during the addition.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
        added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)

        return {"deleted_entities": deleted_entities, "added_entities": added_entities}
    def _retrieve_nodes_from_data(self, data, filters):
        """
        Extract all entities mentioned in the query.
        """
        _tools = [EXTRACT_ENTITIES_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
                },
                {"role": "user", "content": data},
            ],
            tools=_tools,
        )

        entity_type_map = {}

        try:
            for tool_call in search_results["tool_calls"]:
                if tool_call["name"] != "extract_entities":
                    continue
                for item in tool_call["arguments"]["entities"]:
                    entity_type_map[item["entity"]] = item["entity_type"]
        except Exception as e:
            logger.exception(
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        return entity_type_map
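The tool-call parsing in `_retrieve_nodes_from_data` can be exercised in isolation. A standalone sketch, using a fabricated `search_results` payload (the entity names and types below are sample data, not part of the module):

```python
# Fabricated LLM tool-call response in the shape consumed above.
search_results = {
    "tool_calls": [
        {
            "name": "extract_entities",
            "arguments": {
                "entities": [
                    {"entity": "Alice Smith", "entity_type": "Person"},
                    {"entity": "Acme Corp", "entity_type": "Company"},
                ]
            },
        }
    ]
}

# Fold tool calls into an entity -> type map, skipping unrelated tools.
entity_type_map = {}
for tool_call in search_results["tool_calls"]:
    if tool_call["name"] != "extract_entities":
        continue
    for item in tool_call["arguments"]["entities"]:
        entity_type_map[item["entity"]] = item["entity_type"]

# Normalize keys and values exactly as the method does.
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
```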
    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
        """
        Establish relations among the extracted nodes.
        """
        if self.config.graph_store.custom_prompt:
            messages = [
                {
                    "role": "system",
                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
                        "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
                    ),
                },
                {"role": "user", "content": data},
            ]
        else:
            messages = [
                {
                    "role": "system",
                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
                },
                {
                    "role": "user",
                    "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
                },
            ]

        _tools = [RELATIONS_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [RELATIONS_STRUCT_TOOL]

        extracted_entities = self.llm.generate_response(
            messages=messages,
            tools=_tools,
        )

        entities = []
        if extracted_entities["tool_calls"]:
            entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]

        entities = self._remove_spaces_from_entities(entities)
        logger.debug(f"Extracted entities: {entities}")
        return entities
    def _remove_spaces_from_entities(self, entity_list):
        for item in entity_list:
            item["source"] = item["source"].lower().replace(" ", "_")
            item["relationship"] = item["relationship"].lower().replace(" ", "_")
            item["destination"] = item["destination"].lower().replace(" ", "_")
        return entity_list
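The normalization applied by `_remove_spaces_from_entities` keeps names and relationship labels safe to splice into OpenCypher labels and relationship types. A minimal standalone sketch (the `normalize` helper and the sample triplet are illustrative, not part of the module):

```python
def normalize(text):
    # Lowercase and replace spaces with underscores, as the module does
    # for sources, relationships, and destinations.
    return text.lower().replace(" ", "_")

entities = [{"source": "Alice Smith", "relationship": "Works At", "destination": "Acme Corp"}]
normalized = [{k: normalize(v) for k, v in item.items()} for item in entities]
```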
    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """
        Get the entities to be deleted from the search output.
        """

        search_output_string = format_entities(search_output)
        system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [
                DELETE_MEMORY_STRUCT_TOOL_GRAPH,
            ]

        memory_updates = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tools=_tools,
        )

        to_be_deleted = []
        for item in memory_updates["tool_calls"]:
            if item["name"] == "delete_graph_memory":
                to_be_deleted.append(item["arguments"])
        # normalize in case the output is not in the correct format
        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
        logger.debug(f"Deleted relationships: {to_be_deleted}")
        return to_be_deleted
    def _delete_entities(self, to_be_deleted, user_id):
        """
        Delete the entities from the graph.
        """

        results = []
        for item in to_be_deleted:
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            # Delete the specific relationship between nodes
            cypher, params = self._delete_entities_cypher(source, destination, relationship, user_id)
            result = self.graph.query(cypher, params=params)
            results.append(result)
        return results

    @abstractmethod
    def _delete_entities_cypher(self, source, destination, relationship, user_id):
        """
        Returns the OpenCypher query and parameters for deleting entities in the graph DB
        """
        pass
    def _add_entities(self, to_be_added, user_id, entity_type_map):
        """
        Add the new entities to the graph. Merge the nodes if they already exist.
        """

        results = []
        for item in to_be_added:
            # entities
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            # types
            source_type = entity_type_map.get(source, "__User__")
            destination_type = entity_type_map.get(destination, "__User__")

            # embeddings
            source_embedding = self.embedding_model.embed(source)
            dest_embedding = self.embedding_model.embed(destination)

            # search for the nodes with the closest embeddings
            source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
            destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)

            cypher, params = self._add_entities_cypher(
                source_node_search_result,
                source,
                source_embedding,
                source_type,
                destination_node_search_result,
                destination,
                dest_embedding,
                destination_type,
                relationship,
                user_id,
            )
            result = self.graph.query(cypher, params=params)
            results.append(result)
        return results

    @abstractmethod
    def _add_entities_cypher(
        self,
        source_node_list,
        source,
        source_embedding,
        source_type,
        destination_node_list,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB
        """
        pass
    def search(self, query, filters, limit=100):
        """
        Search for memories and related graph data.

        Args:
            query (str): Query to search for.
            filters (dict): A dictionary containing filters to be applied during the search.
            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.

        Returns:
            list: The top matching relationships, each a dictionary containing
                "source", "relationship", and "destination" keys.
        """

        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)

        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)

        tokenized_query = query.split(" ")
        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)

        search_results = []
        for item in reranked_results:
            search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})

        return search_results
    def _search_source_node(self, source_embedding, user_id, threshold=0.9):
        cypher, params = self._search_source_node_cypher(source_embedding, user_id, threshold)
        result = self.graph.query(cypher, params=params)
        return result

    @abstractmethod
    def _search_source_node_cypher(self, source_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for source nodes
        """
        pass

    def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
        cypher, params = self._search_destination_node_cypher(destination_embedding, user_id, threshold)
        result = self.graph.query(cypher, params=params)
        return result

    @abstractmethod
    def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for destination nodes
        """
        pass
    def delete_all(self, filters):
        cypher, params = self._delete_all_cypher(filters)
        self.graph.query(cypher, params=params)

    @abstractmethod
    def _delete_all_cypher(self, filters):
        """
        Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
        """
        pass
    def get_all(self, filters, limit=100):
        """
        Retrieves all nodes and relationships from the graph database based on filtering criteria.

        Args:
            filters (dict): A dictionary containing filters to be applied during the retrieval.
            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.

        Returns:
            list: A list of dictionaries, each containing:
                - 'source': The source node name.
                - 'relationship': The relationship type.
                - 'target': The target node name.
        """

        # return all nodes and relationships
        query, params = self._get_all_cypher(filters, limit)
        results = self.graph.query(query, params=params)

        final_results = []
        for result in results:
            final_results.append(
                {
                    "source": result["source"],
                    "relationship": result["relationship"],
                    "target": result["target"],
                }
            )

        logger.debug(f"Retrieved {len(final_results)} relationships")

        return final_results

    @abstractmethod
    def _get_all_cypher(self, filters, limit):
        """
        Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
        """
        pass
    def _search_graph_db(self, node_list, filters, limit=100):
        """
        Search for similar nodes and their respective incoming and outgoing relations.
        """
        result_relations = []

        for node in node_list:
            n_embedding = self.embedding_model.embed(node)
            cypher_query, params = self._search_graph_db_cypher(n_embedding, filters, limit)
            ans = self.graph.query(cypher_query, params=params)
            result_relations.extend(ans)

        return result_relations

    @abstractmethod
    def _search_graph_db_cypher(self, n_embedding, filters, limit):
        """
        Returns the OpenCypher query and parameters to search for similar nodes in the memory store
        """
        pass
    # Reset is not defined in base.py
    def reset(self):
        """
        Reset the graph by clearing all nodes and relationships.

        link: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/neptune-graph/client/reset_graph.html
        """

        logger.warning("Clearing graph...")
        graph_id = self.graph.graph_identifier
        self.graph.client.reset_graph(
            graphIdentifier=graph_id,
            skipSnapshot=True,
        )
        waiter = self.graph.client.get_waiter("graph_available")
        waiter.wait(graphIdentifier=graph_id, WaiterConfig={"Delay": 10, "MaxAttempts": 60})
372
mem0/graphs/neptune/main.py
Normal file
@@ -0,0 +1,372 @@
import logging

from .base import NeptuneBase

try:
    from langchain_aws import NeptuneAnalyticsGraph
except ImportError:
    raise ImportError("langchain_aws is not installed. Please install it using 'make install_all'.")

logger = logging.getLogger(__name__)


class MemoryGraph(NeptuneBase):
    def __init__(self, config):
        self.config = config

        self.graph = None
        endpoint = self.config.graph_store.config.endpoint
        if endpoint and endpoint.startswith("neptune-graph://"):
            graph_identifier = endpoint.replace("neptune-graph://", "")
            self.graph = NeptuneAnalyticsGraph(graph_identifier)

        if not self.graph:
            raise ValueError("Unable to create a Neptune client: missing 'endpoint' in config")

        self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""

        self.embedding_model = NeptuneBase._create_embedding_model(self.config)

        self.llm_provider = "openai_structured"
        if self.config.llm.provider:
            self.llm_provider = self.config.llm.provider
        if self.config.graph_store.llm:
            self.llm_provider = self.config.graph_store.llm.provider

        self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
        self.user_id = None
        self.threshold = 0.7
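The endpoint handling in `__init__` reduces to a small parsing rule: only the `neptune-graph://` scheme of Neptune Analytics is accepted, and anything else (including a `neptune-db://` endpoint or a missing value) is rejected. `parse_graph_identifier` below is a hypothetical helper that mirrors that logic for illustration; it is not part of the module:

```python
def parse_graph_identifier(endpoint):
    # Accept only Neptune Analytics endpoints of the form neptune-graph://<id>;
    # None or any other scheme raises, matching __init__ above.
    prefix = "neptune-graph://"
    if endpoint and endpoint.startswith(prefix):
        return endpoint[len(prefix):]
    raise ValueError("Unable to create a Neptune client: missing 'endpoint' in config")
```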
    def _delete_entities_cypher(self, source, destination, relationship, user_id):
        """
        Returns the OpenCypher query and parameters for deleting entities in the graph DB

        :param source: source node
        :param destination: destination node
        :param relationship: relationship label
        :param user_id: user_id to use
        :return: str, dict
        """

        cypher = f"""
            MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
            -[r:{relationship}]->
            (m {self.node_label} {{name: $dest_name, user_id: $user_id}})
            DELETE r
            RETURN
                n.name AS source,
                m.name AS target,
                type(r) AS relationship
            """
        params = {
            "source_name": source,
            "dest_name": destination,
            "user_id": user_id,
        }
        logger.debug(f"_delete_entities\n query={cypher}")
        return cypher, params
    def _add_entities_cypher(
        self,
        source_node_list,
        source,
        source_embedding,
        source_type,
        destination_node_list,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB

        :param source_node_list: list of source nodes
        :param source: source node name
        :param source_embedding: source node embedding
        :param source_type: source node label
        :param destination_node_list: list of dest nodes
        :param destination: destination name
        :param dest_embedding: destination embedding
        :param destination_type: destination node label
        :param relationship: relationship label
        :param user_id: user id to use
        :return: str, dict
        """

        source_label = self.node_label if self.node_label else f":`{source_type}`"
        source_extra_set = f", source:`{source_type}`" if self.node_label else ""
        destination_label = self.node_label if self.node_label else f":`{destination_type}`"
        destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""

        # TODO: refactor this code together with the graph_memory.py implementation
        if not destination_node_list and source_node_list:
            cypher = f"""
                MATCH (source)
                WHERE id(source) = $source_id
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
                ON CREATE SET
                    destination.created = timestamp(),
                    destination.mentions = 1
                    {destination_extra_set}
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1
                WITH source, destination, $dest_embedding as dest_embedding
                CALL neptune.algo.vectors.upsert(destination, dest_embedding)
                WITH source, destination
                MERGE (source)-[r:{relationship}]->(destination)
                ON CREATE SET
                    r.created = timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
                """

            params = {
                "source_id": source_node_list[0]["id(source_candidate)"],
                "destination_name": destination,
                "dest_embedding": dest_embedding,
                "user_id": user_id,
            }
        elif destination_node_list and not source_node_list:
            cypher = f"""
                MATCH (destination)
                WHERE id(destination) = $destination_id
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                WITH destination
                MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
                ON CREATE SET
                    source.created = timestamp(),
                    source.mentions = 1
                    {source_extra_set}
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1
                WITH source, destination, $source_embedding as source_embedding
                CALL neptune.algo.vectors.upsert(source, source_embedding)
                WITH source, destination
                MERGE (source)-[r:{relationship}]->(destination)
                ON CREATE SET
                    r.created = timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
                """

            params = {
                "destination_id": destination_node_list[0]["id(destination_candidate)"],
                "source_name": source,
                "source_embedding": source_embedding,
                "user_id": user_id,
            }
        elif source_node_list and destination_node_list:
            cypher = f"""
                MATCH (source)
                WHERE id(source) = $source_id
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MATCH (destination)
                WHERE id(destination) = $destination_id
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                MERGE (source)-[r:{relationship}]->(destination)
                ON CREATE SET
                    r.created_at = timestamp(),
                    r.updated_at = timestamp(),
                    r.mentions = 1
                ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
                """
            params = {
                "source_id": source_node_list[0]["id(source_candidate)"],
                "destination_id": destination_node_list[0]["id(destination_candidate)"],
                "user_id": user_id,
            }
        else:
            cypher = f"""
                MERGE (n {source_label} {{name: $source_name, user_id: $user_id}})
                ON CREATE SET n.created = timestamp(),
                    n.mentions = 1
                    {source_extra_set}
                ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
                WITH n, $source_embedding as source_embedding
                CALL neptune.algo.vectors.upsert(n, source_embedding)
                WITH n
                MERGE (m {destination_label} {{name: $dest_name, user_id: $user_id}})
                ON CREATE SET m.created = timestamp(),
                    m.mentions = 1
                    {destination_extra_set}
                ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
                WITH n, m, $dest_embedding as dest_embedding
                CALL neptune.algo.vectors.upsert(m, dest_embedding)
                WITH n, m
                MERGE (n)-[rel:{relationship}]->(m)
                ON CREATE SET rel.created = timestamp(), rel.mentions = 1
                ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
                RETURN n.name AS source, type(rel) AS relationship, m.name AS target
                """
            params = {
                "source_name": source,
                "dest_name": destination,
                "source_embedding": source_embedding,
                "dest_embedding": dest_embedding,
                "user_id": user_id,
            }
        logger.debug(
            f"_add_entities:\n destination_node_search_result={destination_node_list}\n source_node_search_result={source_node_list}\n query={cypher}"
        )
        return cypher, params
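The four Cypher templates in `_add_entities_cypher` are selected purely from whether a similar existing source and/or destination node was found. `pick_branch` below is an illustrative stand-in for that dispatch only; it returns descriptive strings rather than Cypher:

```python
def pick_branch(source_node_list, destination_node_list):
    # Mirrors the if/elif ordering in _add_entities_cypher.
    if not destination_node_list and source_node_list:
        return "match source by id, MERGE destination"
    if destination_node_list and not source_node_list:
        return "match destination by id, MERGE source"
    if source_node_list and destination_node_list:
        return "match both by id, MERGE relationship only"
    return "MERGE both nodes and the relationship"
```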
    def _search_source_node_cypher(self, source_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for source nodes

        :param source_embedding: source vector
        :param user_id: user_id to use
        :param threshold: the threshold for similarity
        :return: str, dict
        """
        cypher = f"""
            MATCH (source_candidate {self.node_label})
            WHERE source_candidate.user_id = $user_id

            WITH source_candidate, $source_embedding as v_embedding
            CALL neptune.algo.vectors.distanceByEmbedding(
                v_embedding,
                source_candidate,
                {{metric:"CosineSimilarity"}}
            ) YIELD distance
            WITH source_candidate, distance AS cosine_similarity
            WHERE cosine_similarity >= $threshold

            WITH source_candidate, cosine_similarity
            ORDER BY cosine_similarity DESC
            LIMIT 1

            RETURN id(source_candidate), cosine_similarity
            """

        params = {
            "source_embedding": source_embedding,
            "user_id": user_id,
            "threshold": threshold,
        }
        logger.debug(f"_search_source_node\n query={cypher}")
        return cypher, params
    def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for destination nodes

        :param destination_embedding: destination vector
        :param user_id: user_id to use
        :param threshold: the threshold for similarity
        :return: str, dict
        """
        cypher = f"""
            MATCH (destination_candidate {self.node_label})
            WHERE destination_candidate.user_id = $user_id

            WITH destination_candidate, $destination_embedding as v_embedding
            CALL neptune.algo.vectors.distanceByEmbedding(
                v_embedding,
                destination_candidate,
                {{metric:"CosineSimilarity"}}
            ) YIELD distance
            WITH destination_candidate, distance AS cosine_similarity
            WHERE cosine_similarity >= $threshold

            WITH destination_candidate, cosine_similarity
            ORDER BY cosine_similarity DESC
            LIMIT 1

            RETURN id(destination_candidate), cosine_similarity
            """
        params = {
            "destination_embedding": destination_embedding,
            "user_id": user_id,
            "threshold": threshold,
        }

        logger.debug(f"_search_destination_node\n query={cypher}")
        return cypher, params
    def _delete_all_cypher(self, filters):
        """
        Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store

        :param filters: search filters
        :return: str, dict
        """
        cypher = f"""
            MATCH (n {self.node_label} {{user_id: $user_id}})
            DETACH DELETE n
            """
        params = {"user_id": filters["user_id"]}

        logger.debug(f"delete_all query={cypher}")
        return cypher, params
    def _get_all_cypher(self, filters, limit):
        """
        Returns the OpenCypher query and parameters to get all edges/nodes in the memory store

        :param filters: search filters
        :param limit: return limit
        :return: str, dict
        """

        cypher = f"""
            MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
            RETURN n.name AS source, type(r) AS relationship, m.name AS target
            LIMIT $limit
            """
        params = {"user_id": filters["user_id"], "limit": limit}
        return cypher, params
    def _search_graph_db_cypher(self, n_embedding, filters, limit):
        """
        Returns the OpenCypher query and parameters to search for similar nodes in the memory store

        :param n_embedding: node vector
        :param filters: search filters
        :param limit: return limit
        :return: str, dict
        """

        cypher_query = f"""
            MATCH (n {self.node_label})
            WHERE n.user_id = $user_id
            WITH n, $n_embedding as n_embedding
            CALL neptune.algo.vectors.distanceByEmbedding(
                n_embedding,
                n,
                {{metric:"CosineSimilarity"}}
            ) YIELD distance
            WITH n, distance as similarity
            WHERE similarity >= $threshold
            CALL {{
                WITH n
                MATCH (n)-[r]->(m)
                RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
                UNION ALL
                WITH n
                MATCH (m)-[r]->(n)
                RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
            }}
            WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
            RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
            ORDER BY similarity DESC
            LIMIT $limit
            """
        params = {
            "n_embedding": n_embedding,
            "threshold": self.threshold,
            "user_id": filters["user_id"],
            "limit": limit,
        }
        logger.debug(f"_search_graph_db\n query={cypher_query}")

        return cypher_query, params
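The vector-search queries above delegate similarity scoring to `neptune.algo.vectors.distanceByEmbedding` with the `CosineSimilarity` metric and then filter by `self.threshold`. A pure-Python sketch of the same scoring and filtering, with fabricated candidate embeddings (illustrative only, not how the module computes similarity in production):

```python
import math

def cosine_similarity(a, b):
    # Standard cosine similarity: dot product over the product of norms.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

# Fabricated candidate embeddings keyed by node name.
candidates = {"node_a": [1.0, 0.0], "node_b": [0.6, 0.8]}
query = [1.0, 0.0]
threshold = 0.7  # mirrors self.threshold = 0.7 in MemoryGraph.__init__

scores = {name: cosine_similarity(query, emb) for name, emb in candidates.items()}
# Keep only candidates at or above the threshold, best first.
matches = [name for name, sim in sorted(scores.items(), key=lambda kv: -kv[1]) if sim >= threshold]
```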
@@ -621,3 +621,12 @@ class MemoryGraph:

        result = self.graph.query(cypher, params=params)
        return result

    # Reset is not defined in base.py
    def reset(self):
        """Reset the graph by clearing all nodes and relationships."""
        logger.warning("Clearing graph...")
        cypher_query = """
        MATCH (n) DETACH DELETE n
        """
        return self.graph.query(cypher_query)
@@ -138,6 +138,8 @@ class Memory(MemoryBase):

        if self.config.graph_store.config:
            if self.config.graph_store.provider == "memgraph":
                from mem0.memory.memgraph_memory import MemoryGraph
            elif self.config.graph_store.provider == "neptune":
                from mem0.graphs.neptune.main import MemoryGraph
            else:
                from mem0.memory.graph_memory import MemoryGraph
@@ -23,6 +23,7 @@ dependencies = [

[project.optional-dependencies]
graph = [
    "langchain-neo4j>=0.4.0",
    "langchain-aws>=0.2.23",
    "neo4j>=5.23.1",
    "rank-bm25>=0.2.2",
]
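With the `neptune` provider registered in `Memory` and `langchain-aws` added to the `graph` extra, the graph store can be enabled through configuration. A minimal sketch, assuming the config shape used by the tests below; `my-graph-identifier` is a placeholder for your Neptune Analytics graph identifier, and AWS credentials are resolved through the standard credential chain:

```python
# Minimal mem0 configuration enabling the Neptune Analytics graph store.
# "my-graph-identifier" is a placeholder; endpoint must use the
# neptune-graph:// scheme expected by MemoryGraph.__init__.
config = {
    "graph_store": {
        "provider": "neptune",
        "config": {
            "endpoint": "neptune-graph://my-graph-identifier",
        },
    },
}

# from mem0 import Memory
# m = Memory.from_config(config_dict=config)
```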
335
tests/memory/test_neptune_memory.py
Normal file
@@ -0,0 +1,335 @@
import unittest
from unittest.mock import MagicMock, patch

import pytest

from mem0.graphs.neptune.base import NeptuneBase
from mem0.graphs.neptune.main import MemoryGraph


class TestNeptuneMemory(unittest.TestCase):
    """Test suite for the Neptune Memory implementation."""

    def setUp(self):
        """Set up test fixtures before each test method."""

        # Create a mock config
        self.config = MagicMock()
        self.config.graph_store.config.endpoint = "neptune-graph://test-graph"
        self.config.graph_store.config.base_label = True
        self.config.llm.provider = "openai_structured"
        self.config.graph_store.llm = None
        self.config.graph_store.custom_prompt = None

        # Create mock for NeptuneAnalyticsGraph
        self.mock_graph = MagicMock()
        self.mock_graph.client.get_graph.return_value = {"status": "AVAILABLE"}

        # Create mocks for static methods
        self.mock_embedding_model = MagicMock()
        self.mock_llm = MagicMock()

        # Patch the necessary components
        self.neptune_analytics_graph_patcher = patch("mem0.graphs.neptune.main.NeptuneAnalyticsGraph")
        self.mock_neptune_analytics_graph = self.neptune_analytics_graph_patcher.start()
        self.mock_neptune_analytics_graph.return_value = self.mock_graph

        # Patch the static methods
        self.create_embedding_model_patcher = patch.object(NeptuneBase, "_create_embedding_model")
        self.mock_create_embedding_model = self.create_embedding_model_patcher.start()
        self.mock_create_embedding_model.return_value = self.mock_embedding_model

        self.create_llm_patcher = patch.object(NeptuneBase, "_create_llm")
        self.mock_create_llm = self.create_llm_patcher.start()
        self.mock_create_llm.return_value = self.mock_llm

        # Create the MemoryGraph instance
        self.memory_graph = MemoryGraph(self.config)

        # Set up common test data
        self.user_id = "test_user"
        self.test_filters = {"user_id": self.user_id}

    def tearDown(self):
        """Tear down test fixtures after each test method."""
        self.neptune_analytics_graph_patcher.stop()
        self.create_embedding_model_patcher.stop()
        self.create_llm_patcher.stop()

    def test_initialization(self):
        """Test that the MemoryGraph is initialized correctly."""
        self.assertEqual(self.memory_graph.graph, self.mock_graph)
        self.assertEqual(self.memory_graph.embedding_model, self.mock_embedding_model)
        self.assertEqual(self.memory_graph.llm, self.mock_llm)
        self.assertEqual(self.memory_graph.llm_provider, "openai_structured")
        self.assertEqual(self.memory_graph.node_label, ":`__Entity__`")
        self.assertEqual(self.memory_graph.threshold, 0.7)

    def test_init(self):
        """Test the class init functions."""

        # Create a mock config with a missing endpoint
        config_no_endpoint = MagicMock()
        config_no_endpoint.graph_store.config.endpoint = None

        # Creating the MemoryGraph instance should fail
        with pytest.raises(ValueError):
            MemoryGraph(config_no_endpoint)

        # Create a mock config with an unsupported (Neptune DB) endpoint
        config_ndb_endpoint = MagicMock()
        config_ndb_endpoint.graph_store.config.endpoint = "neptune-db://test-graph"

        with pytest.raises(ValueError):
            MemoryGraph(config_ndb_endpoint)

    def test_add_method(self):
        """Test the add method with mocked components."""

        # Mock the necessary methods that add() calls
        self.memory_graph._retrieve_nodes_from_data = MagicMock(return_value={"alice": "person", "bob": "person"})
        self.memory_graph._establish_nodes_relations_from_data = MagicMock(
            return_value=[{"source": "alice", "relationship": "knows", "destination": "bob"}]
        )
        self.memory_graph._search_graph_db = MagicMock(return_value=[])
        self.memory_graph._get_delete_entities_from_search_output = MagicMock(return_value=[])
        self.memory_graph._delete_entities = MagicMock(return_value=[])
        self.memory_graph._add_entities = MagicMock(
            return_value=[{"source": "alice", "relationship": "knows", "target": "bob"}]
        )

        # Call the add method
        result = self.memory_graph.add("Alice knows Bob", self.test_filters)

        # Verify the method calls
        self.memory_graph._retrieve_nodes_from_data.assert_called_once_with("Alice knows Bob", self.test_filters)
        self.memory_graph._establish_nodes_relations_from_data.assert_called_once()
        self.memory_graph._search_graph_db.assert_called_once()
        self.memory_graph._get_delete_entities_from_search_output.assert_called_once()
        self.memory_graph._delete_entities.assert_called_once_with([], self.user_id)
        self.memory_graph._add_entities.assert_called_once()

        # Check the result structure
        self.assertIn("deleted_entities", result)
        self.assertIn("added_entities", result)

    def test_search_method(self):
        """Test the search method with mocked components."""
        # Mock the necessary methods that search() calls
        self.memory_graph._retrieve_nodes_from_data = MagicMock(return_value={"alice": "person"})

        # Mock search results
        mock_search_results = [
            {"source": "alice", "relationship": "knows", "destination": "bob"},
            {"source": "alice", "relationship": "works_with", "destination": "charlie"},
        ]
        self.memory_graph._search_graph_db = MagicMock(return_value=mock_search_results)

        # Mock BM25Okapi
        with patch("mem0.graphs.neptune.base.BM25Okapi") as mock_bm25:
            mock_bm25_instance = MagicMock()
            mock_bm25.return_value = mock_bm25_instance

            # Mock get_top_n to return reranked results
            reranked_results = [["alice", "knows", "bob"], ["alice", "works_with", "charlie"]]
            mock_bm25_instance.get_top_n.return_value = reranked_results

            # Call the search method
            result = self.memory_graph.search("Find Alice", self.test_filters, limit=5)

            # Verify the method calls
            self.memory_graph._retrieve_nodes_from_data.assert_called_once_with("Find Alice", self.test_filters)
            self.memory_graph._search_graph_db.assert_called_once_with(node_list=["alice"], filters=self.test_filters)

            # Check the result structure
            self.assertEqual(len(result), 2)
            self.assertEqual(result[0]["source"], "alice")
            self.assertEqual(result[0]["relationship"], "knows")
|
||||
self.assertEqual(result[0]["destination"], "bob")
|
||||
|
||||
def test_get_all_method(self):
|
||||
"""Test the get_all method."""
|
||||
|
||||
# Mock the _get_all_cypher method
|
||||
mock_cypher = "MATCH (n) RETURN n"
|
||||
mock_params = {"user_id": self.user_id, "limit": 10}
|
||||
self.memory_graph._get_all_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Mock the graph.query result
|
||||
mock_query_result = [
|
||||
{"source": "alice", "relationship": "knows", "target": "bob"},
|
||||
{"source": "bob", "relationship": "works_with", "target": "charlie"},
|
||||
]
|
||||
self.mock_graph.query.return_value = mock_query_result
|
||||
|
||||
# Call the get_all method
|
||||
result = self.memory_graph.get_all(self.test_filters, limit=10)
|
||||
|
||||
# Verify the method calls
|
||||
self.memory_graph._get_all_cypher.assert_called_once_with(self.test_filters, 10)
|
||||
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
|
||||
|
||||
# Check the result structure
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["source"], "alice")
|
||||
self.assertEqual(result[0]["relationship"], "knows")
|
||||
self.assertEqual(result[0]["target"], "bob")
|
||||
|
||||
def test_delete_all_method(self):
|
||||
"""Test the delete_all method."""
|
||||
# Mock the _delete_all_cypher method
|
||||
mock_cypher = "MATCH (n) DETACH DELETE n"
|
||||
mock_params = {"user_id": self.user_id}
|
||||
self.memory_graph._delete_all_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Call the delete_all method
|
||||
self.memory_graph.delete_all(self.test_filters)
|
||||
|
||||
# Verify the method calls
|
||||
self.memory_graph._delete_all_cypher.assert_called_once_with(self.test_filters)
|
||||
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
|
||||
|
||||
def test_search_source_node(self):
|
||||
"""Test the _search_source_node method."""
|
||||
# Mock embedding
|
||||
mock_embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
# Mock the _search_source_node_cypher method
|
||||
mock_cypher = "MATCH (n) RETURN n"
|
||||
mock_params = {"source_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.9}
|
||||
self.memory_graph._search_source_node_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Mock the graph.query result
|
||||
mock_query_result = [{"id(source_candidate)": 123, "cosine_similarity": 0.95}]
|
||||
self.mock_graph.query.return_value = mock_query_result
|
||||
|
||||
# Call the _search_source_node method
|
||||
result = self.memory_graph._search_source_node(mock_embedding, self.user_id, threshold=0.9)
|
||||
|
||||
# Verify the method calls
|
||||
self.memory_graph._search_source_node_cypher.assert_called_once_with(mock_embedding, self.user_id, 0.9)
|
||||
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
|
||||
|
||||
# Check the result
|
||||
self.assertEqual(result, mock_query_result)
|
||||
|
||||
def test_search_destination_node(self):
|
||||
"""Test the _search_destination_node method."""
|
||||
# Mock embedding
|
||||
mock_embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
# Mock the _search_destination_node_cypher method
|
||||
mock_cypher = "MATCH (n) RETURN n"
|
||||
mock_params = {"destination_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.9}
|
||||
self.memory_graph._search_destination_node_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Mock the graph.query result
|
||||
mock_query_result = [{"id(destination_candidate)": 456, "cosine_similarity": 0.92}]
|
||||
self.mock_graph.query.return_value = mock_query_result
|
||||
|
||||
# Call the _search_destination_node method
|
||||
result = self.memory_graph._search_destination_node(mock_embedding, self.user_id, threshold=0.9)
|
||||
|
||||
# Verify the method calls
|
||||
self.memory_graph._search_destination_node_cypher.assert_called_once_with(mock_embedding, self.user_id, 0.9)
|
||||
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
|
||||
|
||||
# Check the result
|
||||
self.assertEqual(result, mock_query_result)
|
||||
|
||||
def test_search_graph_db(self):
|
||||
"""Test the _search_graph_db method."""
|
||||
# Mock node list
|
||||
node_list = ["alice", "bob"]
|
||||
|
||||
# Mock embedding
|
||||
mock_embedding = [0.1, 0.2, 0.3]
|
||||
self.mock_embedding_model.embed.return_value = mock_embedding
|
||||
|
||||
# Mock the _search_graph_db_cypher method
|
||||
mock_cypher = "MATCH (n) RETURN n"
|
||||
mock_params = {"n_embedding": mock_embedding, "user_id": self.user_id, "threshold": 0.7, "limit": 10}
|
||||
self.memory_graph._search_graph_db_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Mock the graph.query results
|
||||
mock_query_result1 = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
|
||||
mock_query_result2 = [{"source": "bob", "relationship": "works_with", "destination": "charlie"}]
|
||||
self.mock_graph.query.side_effect = [mock_query_result1, mock_query_result2]
|
||||
|
||||
# Call the _search_graph_db method
|
||||
result = self.memory_graph._search_graph_db(node_list, self.test_filters, limit=10)
|
||||
|
||||
# Verify the method calls
|
||||
self.assertEqual(self.mock_embedding_model.embed.call_count, 2)
|
||||
self.assertEqual(self.memory_graph._search_graph_db_cypher.call_count, 2)
|
||||
self.assertEqual(self.mock_graph.query.call_count, 2)
|
||||
|
||||
# Check the result
|
||||
expected_result = mock_query_result1 + mock_query_result2
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_add_entities(self):
|
||||
"""Test the _add_entities method."""
|
||||
# Mock data
|
||||
to_be_added = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
|
||||
entity_type_map = {"alice": "person", "bob": "person"}
|
||||
|
||||
# Mock embeddings
|
||||
mock_embedding = [0.1, 0.2, 0.3]
|
||||
self.mock_embedding_model.embed.return_value = mock_embedding
|
||||
|
||||
# Mock search results
|
||||
mock_source_search = [{"id(source_candidate)": 123, "cosine_similarity": 0.95}]
|
||||
mock_dest_search = [{"id(destination_candidate)": 456, "cosine_similarity": 0.92}]
|
||||
|
||||
# Mock the search methods
|
||||
self.memory_graph._search_source_node = MagicMock(return_value=mock_source_search)
|
||||
self.memory_graph._search_destination_node = MagicMock(return_value=mock_dest_search)
|
||||
|
||||
# Mock the _add_entities_cypher method
|
||||
mock_cypher = "MATCH (n) RETURN n"
|
||||
mock_params = {"source_id": 123, "destination_id": 456}
|
||||
self.memory_graph._add_entities_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Mock the graph.query result
|
||||
mock_query_result = [{"source": "alice", "relationship": "knows", "target": "bob"}]
|
||||
self.mock_graph.query.return_value = mock_query_result
|
||||
|
||||
# Call the _add_entities method
|
||||
result = self.memory_graph._add_entities(to_be_added, self.user_id, entity_type_map)
|
||||
|
||||
# Verify the method calls
|
||||
self.assertEqual(self.mock_embedding_model.embed.call_count, 2)
|
||||
self.memory_graph._search_source_node.assert_called_once_with(mock_embedding, self.user_id, threshold=0.9)
|
||||
self.memory_graph._search_destination_node.assert_called_once_with(mock_embedding, self.user_id, threshold=0.9)
|
||||
self.memory_graph._add_entities_cypher.assert_called_once()
|
||||
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
|
||||
|
||||
# Check the result
|
||||
self.assertEqual(result, [mock_query_result])
|
||||
|
||||
def test_delete_entities(self):
|
||||
"""Test the _delete_entities method."""
|
||||
# Mock data
|
||||
to_be_deleted = [{"source": "alice", "relationship": "knows", "destination": "bob"}]
|
||||
|
||||
# Mock the _delete_entities_cypher method
|
||||
mock_cypher = "MATCH (n) RETURN n"
|
||||
mock_params = {"source_name": "alice", "dest_name": "bob", "user_id": self.user_id}
|
||||
self.memory_graph._delete_entities_cypher = MagicMock(return_value=(mock_cypher, mock_params))
|
||||
|
||||
# Mock the graph.query result
|
||||
mock_query_result = [{"source": "alice", "relationship": "knows", "target": "bob"}]
|
||||
self.mock_graph.query.return_value = mock_query_result
|
||||
|
||||
# Call the _delete_entities method
|
||||
result = self.memory_graph._delete_entities(to_be_deleted, self.user_id)
|
||||
|
||||
# Verify the method calls
|
||||
self.memory_graph._delete_entities_cypher.assert_called_once_with("alice", "bob", "knows", self.user_id)
|
||||
self.mock_graph.query.assert_called_once_with(mock_cypher, params=mock_params)
|
||||
|
||||
# Check the result
|
||||
self.assertEqual(result, [mock_query_result])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|