Fix CI issues related to missing dependency (#3096)

Deshraj Yadav
2025-07-03 18:52:50 -07:00
committed by GitHub
parent 2c496e6376
commit 7484eed4b2
32 changed files with 6150 additions and 828 deletions

View File

@@ -58,11 +58,10 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install -e ".[test,graph,vector_stores,llms,extras]"
-          pip install ruff
        if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
-      - name: Run Formatting
-        run: |
-          mkdir -p .ruff_cache && chmod -R 777 .ruff_cache
-          hatch run format
+      - name: Run Linting
+        run: make lint
      - name: Run tests and generate coverage report
        run: make test

View File

@@ -13,7 +13,7 @@ install:
install_all:
	pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
	google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
-	upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow
+	upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo

# Format code with ruff
format:
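The only functional change in this Makefile hunk is the new `pymongo` package, which backs the MongoDB vector store touched elsewhere in this commit (see the `MongoDBConfig` and `MongoDB` store hunks below). A minimal sketch of wiring it into mem0, assuming the provider key is `"mongodb"` and that the config fields mirror the `MongoDBConfig` defaults shown later; names and values are illustrative, not confirmed by this diff:

```python
from mem0 import Memory

# Hypothetical config: the "mongodb" provider key and the field names are
# assumptions based on MongoDBConfig further down in this commit.
config = {
    "vector_store": {
        "provider": "mongodb",
        "config": {
            "db_name": "mem0",
            "collection_name": "memories",
            "embedding_model_dims": 1536,
            "mongo_uri": "mongodb://localhost:27017",
        },
    },
}

m = Memory.from_config(config_dict=config)  # fails without pymongo installed, hence the Makefile change
```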

View File

@@ -110,7 +110,7 @@ def main():
print("All categories accuracy:") print("All categories accuracy:")
for cat, results in LLM_JUDGE.items(): for cat, results in LLM_JUDGE.items():
if results: # Only print if there are results for this category if results: # Only print if there are results for this category
print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})") print(f" Category {cat}: {np.mean(results):.4f} ({sum(results)}/{len(results)})")
print("------------------------------------------") print("------------------------------------------")
index += 1 index += 1

View File

@@ -68,7 +68,7 @@ class RAGManager:
    def clean_chat_history(self, chat_history):
        cleaned_chat_history = ""
        for c in chat_history:
-            cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
+            cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"
        return cleaned_chat_history

View File

@@ -1,271 +1,267 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ApdaLD4Qi30H"
},
"source": [
"# Neo4j as Graph Memory"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l7bi3i21i30I"
},
"source": [
"## Prerequisites\n",
"\n",
"### 1. Install Mem0 with Graph Memory support\n",
"\n",
"To use Mem0 with Graph Memory support, install it using pip:\n",
"\n",
"```bash\n",
"pip install \"mem0ai[graph]\"\n",
"```\n",
"\n",
"This command installs Mem0 along with the necessary dependencies for graph functionality.\n",
"\n",
"### 2. Install Neo4j\n",
"\n",
"To utilize Neo4j as Graph Memory, run it with Docker:\n",
"\n",
"```bash\n",
"docker run \\\n",
" -p 7474:7474 -p 7687:7687 \\\n",
" -e NEO4J_AUTH=neo4j/password \\\n",
" neo4j:5\n",
"```\n",
"\n",
"This command starts Neo4j with default credentials (`neo4j` / `password`) and exposes both the HTTP (7474) and Bolt (7687) ports.\n",
"\n",
"You can access the Neo4j browser at [http://localhost:7474](http://localhost:7474).\n",
"\n",
"Additional information can be found in the [Neo4j documentation](https://neo4j.com/docs/).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DkeBdFEpi30I"
},
"source": [
"## Configuration\n",
"\n",
"Do all the imports and configure OpenAI (enter your OpenAI API key):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "d99EfBpii30I"
},
"outputs": [],
"source": [
"from mem0 import Memory\n",
"\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = (\n",
" \"\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QTucZJjIi30J"
},
"source": [
"Set up configuration to use the embedder model and Neo4j as a graph store:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "QSE0RFoSi30J"
},
"outputs": [],
"source": [
"config = {\n",
" \"embedder\": {\n",
" \"provider\": \"openai\",\n",
" \"config\": {\"model\": \"text-embedding-3-large\", \"embedding_dims\": 1536},\n",
" },\n",
" \"graph_store\": {\n",
" \"provider\": \"neo4j\",\n",
" \"config\": {\n",
" \"url\": \"bolt://54.87.227.131:7687\",\n",
" \"username\": \"neo4j\",\n",
" \"password\": \"causes-bins-vines\",\n",
" },\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OioTnv6xi30J"
},
"source": [
"## Graph Memory initializiation\n",
"\n",
"Initialize Neo4j as a Graph Memory store:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "fX-H9vgNi30J"
},
"outputs": [],
"source": [
"m = Memory.from_config(config_dict=config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kr1fVMwEi30J"
},
"source": [
"## Store memories\n",
"\n",
"Create memories:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "sEfogqp_i30J"
},
"outputs": [],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"I'm planning to watch a movie tonight. Any recommendations?\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"How about a thriller movies? They can be quite engaging.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"I'm not a big fan of thriller movies but I love sci-fi movies.\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n",
"]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gtBHCyIgi30J"
},
"source": [
"Store memories in Neo4j:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "BMVGgZMFi30K"
},
"outputs": [],
"source": [
"# Store inferred memories (default behavior)\n",
"result = m.add(\n",
" messages, user_id=\"alice\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lQRptOywi30K"
},
"source": [
"![](https://github.com/tomasonjo/mem0/blob/neo4jexample/examples/graph-db-demo/alice-memories.png?raw=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LBXW7Gv-i30K"
},
"source": [
"## Search memories"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UHFDeQBEi30K",
"outputId": "2c69de7d-a79a-48f6-e3c4-bd743067857c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loves sci-fi movies 0.3153664287340898\n",
"Planning to watch a movie tonight 0.09683349296551162\n",
"Not a big fan of thriller movies 0.09468540071789466\n"
]
}
],
"source": [
"for result in m.search(\"what does alice love?\", user_id=\"alice\")[\"results\"]:\n",
" print(result[\"memory\"], result[\"score\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "2jXEIma9kK_Q"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
"metadata": {
"id": "l7bi3i21i30I"
},
"source": [
"## Prerequisites\n",
"\n",
"### 1. Install Mem0 with Graph Memory support\n",
"\n",
"To use Mem0 with Graph Memory support, install it using pip:\n",
"\n",
"```bash\n",
"pip install \"mem0ai[graph]\"\n",
"```\n",
"\n",
"This command installs Mem0 along with the necessary dependencies for graph functionality.\n",
"\n",
"### 2. Install Neo4j\n",
"\n",
"To utilize Neo4j as Graph Memory, run it with Docker:\n",
"\n",
"```bash\n",
"docker run \\\n",
" -p 7474:7474 -p 7687:7687 \\\n",
" -e NEO4J_AUTH=neo4j/password \\\n",
" neo4j:5\n",
"```\n",
"\n",
"This command starts Neo4j with default credentials (`neo4j` / `password`) and exposes both the HTTP (7474) and Bolt (7687) ports.\n",
"\n",
"You can access the Neo4j browser at [http://localhost:7474](http://localhost:7474).\n",
"\n",
"Additional information can be found in the [Neo4j documentation](https://neo4j.com/docs/).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DkeBdFEpi30I"
},
"source": [
"## Configuration\n",
"\n",
"Do all the imports and configure OpenAI (enter your OpenAI API key):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "d99EfBpii30I"
},
"outputs": [],
"source": [
"from mem0 import Memory\n",
"\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QTucZJjIi30J"
},
"source": [
"Set up configuration to use the embedder model and Neo4j as a graph store:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "QSE0RFoSi30J"
},
"outputs": [],
"source": [
"config = {\n",
" \"embedder\": {\n",
" \"provider\": \"openai\",\n",
" \"config\": {\"model\": \"text-embedding-3-large\", \"embedding_dims\": 1536},\n",
" },\n",
" \"graph_store\": {\n",
" \"provider\": \"neo4j\",\n",
" \"config\": {\n",
" \"url\": \"bolt://54.87.227.131:7687\",\n",
" \"username\": \"neo4j\",\n",
" \"password\": \"causes-bins-vines\",\n",
" },\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OioTnv6xi30J"
},
"source": [
"## Graph Memory initializiation\n",
"\n",
"Initialize Neo4j as a Graph Memory store:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "fX-H9vgNi30J"
},
"outputs": [],
"source": [
"m = Memory.from_config(config_dict=config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kr1fVMwEi30J"
},
"source": [
"## Store memories\n",
"\n",
"Create memories:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "sEfogqp_i30J"
},
"outputs": [],
"source": [
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"I'm planning to watch a movie tonight. Any recommendations?\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"How about a thriller movies? They can be quite engaging.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"I'm not a big fan of thriller movies but I love sci-fi movies.\",\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gtBHCyIgi30J"
},
"source": [
"Store memories in Neo4j:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "BMVGgZMFi30K"
},
"outputs": [],
"source": [
"# Store inferred memories (default behavior)\n",
"result = m.add(messages, user_id=\"alice\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lQRptOywi30K"
},
"source": [
"![](https://github.com/tomasonjo/mem0/blob/neo4jexample/examples/graph-db-demo/alice-memories.png?raw=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LBXW7Gv-i30K"
},
"source": [
"## Search memories"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UHFDeQBEi30K",
"outputId": "2c69de7d-a79a-48f6-e3c4-bd743067857c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loves sci-fi movies 0.3153664287340898\n",
"Planning to watch a movie tonight 0.09683349296551162\n",
"Not a big fan of thriller movies 0.09468540071789466\n"
]
}
],
"source": [
"for result in m.search(\"what does alice love?\", user_id=\"alice\")[\"results\"]:\n",
" print(result[\"memory\"], result[\"score\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "2jXEIma9kK_Q"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
} }

View File

@@ -45,11 +45,7 @@ def get_food_recommendation(user_query: str, user_id):
"""Get food recommendation with memory context""" """Get food recommendation with memory context"""
# Search memory for relevant food preferences # Search memory for relevant food preferences
memories_result = memory_client.search( memories_result = memory_client.search(query=user_query, user_id=user_id, limit=5)
query=user_query,
user_id=user_id,
limit=5
)
# Add memory context to the message # Add memory context to the message
memories = [f"- {result['memory']}" for result in memories_result] memories = [f"- {result['memory']}" for result in memories_result]
@@ -71,6 +67,7 @@ def get_food_recommendation(user_query: str, user_id):
    # Save audio file
    if response.audio:
        import time

        timestamp = int(time.time())
        filename = f"food_recommendation_{timestamp}.mp3"
        write_audio_to_file(
@@ -118,7 +115,11 @@ def initialize_food_memory(user_id):
# Initialize the memory for the user once in order for the agent to learn the user preference
initialize_food_memory(user_id=USER_ID)
-print(get_food_recommendation("Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID))
+print(
+    get_food_recommendation(
+        "Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID
+    )
+)

# OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3
# For dinner tonight, considering your love for healthy spic optionsy, you could try a nice Thai, Indian, or Mexican restaurant.
# You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner!

View File

@@ -1,4 +1,4 @@
-from agents import Agent, Runner, function_tool, handoffs, enable_verbose_stdout_logging
+from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from dotenv import load_dotenv
from mem0 import MemoryClient
@@ -35,7 +35,7 @@ travel_agent = Agent(
    understand the user's travel preferences and history before making recommendations.
    After providing your response, use store_conversation to save important details.""",
    tools=[search_memory, save_memory],
-    model="gpt-4o"
+    model="gpt-4o",
)

health_agent = Agent(
@@ -44,7 +44,7 @@ health_agent = Agent(
    understand the user's health goals and dietary preferences.
    After providing advice, use store_conversation to save relevant information.""",
    tools=[search_memory, save_memory],
-    model="gpt-4o"
+    model="gpt-4o",
)

# Triage agent with handoffs
@@ -55,7 +55,7 @@ triage_agent = Agent(
    For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor.
    For general questions, you can handle them directly using available tools.""",
    handoffs=[travel_agent, health_agent],
-    model="gpt-4o"
+    model="gpt-4o",
)
@@ -74,10 +74,7 @@ def chat_with_handoffs(user_input: str, user_id: str) -> str:
    result = Runner.run_sync(triage_agent, user_input)

    # Store the original conversation in memory
-    conversation = [
-        {"role": "user", "content": user_input},
-        {"role": "assistant", "content": result.final_output}
-    ]
+    conversation = [{"role": "user", "content": user_input}, {"role": "assistant", "content": result.final_output}]
    mem0.add(conversation, user_id=user_id)

    return result.final_output

View File

@@ -34,96 +34,91 @@ config = {
"api_key": "vllm-api-key", "api_key": "vllm-api-key",
"temperature": 0.7, "temperature": 0.7,
"max_tokens": 100, "max_tokens": 100,
} },
},
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small"
}
}, },
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
"vector_store": { "vector_store": {
"provider": "qdrant", "provider": "qdrant",
"config": { "config": {"collection_name": "vllm_memories", "host": "localhost", "port": 6333},
"collection_name": "vllm_memories", },
"host": "localhost",
"port": 6333
}
}
} }
def main(): def main():
""" """
Demonstrate vLLM integration with mem0 Demonstrate vLLM integration with mem0
""" """
print("--> Initializing mem0 with vLLM...") print("--> Initializing mem0 with vLLM...")
# Initialize memory with vLLM # Initialize memory with vLLM
memory = Memory.from_config(config) memory = Memory.from_config(config)
print("--> Memory initialized successfully!") print("--> Memory initialized successfully!")
# Example conversations to store # Example conversations to store
conversations = [ conversations = [
{ {
"messages": [ "messages": [
{"role": "user", "content": "I love playing chess on weekends"}, {"role": "user", "content": "I love playing chess on weekends"},
{"role": "assistant", "content": "That's great! Chess is an excellent strategic game that helps improve critical thinking."} {
"role": "assistant",
"content": "That's great! Chess is an excellent strategic game that helps improve critical thinking.",
},
], ],
"user_id": "user_123" "user_id": "user_123",
}, },
{ {
"messages": [ "messages": [
{"role": "user", "content": "I'm learning Python programming"}, {"role": "user", "content": "I'm learning Python programming"},
{"role": "assistant", "content": "Python is a fantastic language for beginners! What specific areas are you focusing on?"} {
"role": "assistant",
"content": "Python is a fantastic language for beginners! What specific areas are you focusing on?",
},
], ],
"user_id": "user_123" "user_id": "user_123",
}, },
{ {
"messages": [ "messages": [
{"role": "user", "content": "I prefer working late at night, I'm more productive then"}, {"role": "user", "content": "I prefer working late at night, I'm more productive then"},
{"role": "assistant", "content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you."} {
"role": "assistant",
"content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you.",
},
], ],
"user_id": "user_123" "user_id": "user_123",
} },
] ]
print("\n--> Adding memories using vLLM...") print("\n--> Adding memories using vLLM...")
# Add memories - now powered by vLLM's high-performance inference # Add memories - now powered by vLLM's high-performance inference
for i, conversation in enumerate(conversations, 1): for i, conversation in enumerate(conversations, 1):
result = memory.add( result = memory.add(messages=conversation["messages"], user_id=conversation["user_id"])
messages=conversation["messages"],
user_id=conversation["user_id"]
)
print(f"Memory {i} added: {result}") print(f"Memory {i} added: {result}")
print("\n🔍 Searching memories...") print("\n🔍 Searching memories...")
# Search memories - vLLM will process the search and memory operations # Search memories - vLLM will process the search and memory operations
search_queries = [ search_queries = [
"What does the user like to do on weekends?", "What does the user like to do on weekends?",
"What is the user learning?", "What is the user learning?",
"When is the user most productive?" "When is the user most productive?",
] ]
for query in search_queries: for query in search_queries:
print(f"\nQuery: {query}") print(f"\nQuery: {query}")
memories = memory.search( memories = memory.search(query=query, user_id="user_123")
query=query,
user_id="user_123"
)
for memory_item in memories: for memory_item in memories:
print(f" - {memory_item['memory']}") print(f" - {memory_item['memory']}")
print("\n--> Getting all memories for user...") print("\n--> Getting all memories for user...")
all_memories = memory.get_all(user_id="user_123") all_memories = memory.get_all(user_id="user_123")
print(f"Total memories stored: {len(all_memories)}") print(f"Total memories stored: {len(all_memories)}")
for memory_item in all_memories: for memory_item in all_memories:
print(f" - {memory_item['memory']}") print(f" - {memory_item['memory']}")
print("\n--> vLLM integration demo completed successfully!") print("\n--> vLLM integration demo completed successfully!")
print("\nBenefits of using vLLM:") print("\nBenefits of using vLLM:")
print(" -> 2.7x higher throughput compared to standard implementations") print(" -> 2.7x higher throughput compared to standard implementations")

View File

@@ -89,9 +89,7 @@ class MemoryClient:
        self.user_id = get_user_id()

        if not self.api_key:
-            raise ValueError(
-                "Mem0 API Key not provided. Please provide an API Key."
-            )
+            raise ValueError("Mem0 API Key not provided. Please provide an API Key.")

        # Create MD5 hash of API key for user_id
        self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
@@ -174,9 +172,7 @@ class MemoryClient:
        response.raise_for_status()
        if "metadata" in kwargs:
            del kwargs["metadata"]
-        capture_client_event(
-            "client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"}
-        )
+        capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"})
        return response.json()

    @api_error_handler
@@ -195,9 +191,7 @@ class MemoryClient:
        params = self._prepare_params()
        response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
        response.raise_for_status()
-        capture_client_event(
-            "client.get", self, {"memory_id": memory_id, "sync_type": "sync"}
-        )
+        capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "sync"})
        return response.json()

    @api_error_handler
@@ -224,13 +218,9 @@ class MemoryClient:
"page": params.pop("page"), "page": params.pop("page"),
"page_size": params.pop("page_size"), "page_size": params.pop("page_size"),
} }
response = self.client.post( response = self.client.post(f"/{version}/memories/", json=params, params=query_params)
f"/{version}/memories/", json=params, params=query_params
)
else: else:
response = self.client.post( response = self.client.post(f"/{version}/memories/", json=params)
f"/{version}/memories/", json=params
)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
@@ -246,9 +236,7 @@ class MemoryClient:
        return response.json()

    @api_error_handler
-    def search(
-        self, query: str, version: str = "v1", **kwargs
-    ) -> List[Dict[str, Any]]:
+    def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
        """Search memories based on a query.

        Args:
@@ -266,9 +254,7 @@ class MemoryClient:
payload = {"query": query} payload = {"query": query}
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
payload.update(params) payload.update(params)
response = self.client.post( response = self.client.post(f"/{version}/memories/search/", json=payload)
f"/{version}/memories/search/", json=payload
)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
@@ -308,13 +294,9 @@ class MemoryClient:
        if metadata is not None:
            payload["metadata"] = metadata

-        capture_client_event(
-            "client.update", self, {"memory_id": memory_id, "sync_type": "sync"}
-        )
+        capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "sync"})
        params = self._prepare_params()
-        response = self.client.put(
-            f"/v1/memories/{memory_id}/", json=payload, params=params
-        )
+        response = self.client.put(f"/v1/memories/{memory_id}/", json=payload, params=params)
        response.raise_for_status()
        return response.json()
@@ -332,13 +314,9 @@ class MemoryClient:
            APIError: If the API request fails.
        """
        params = self._prepare_params()
-        response = self.client.delete(
-            f"/v1/memories/{memory_id}/", params=params
-        )
+        response = self.client.delete(f"/v1/memories/{memory_id}/", params=params)
        response.raise_for_status()
-        capture_client_event(
-            "client.delete", self, {"memory_id": memory_id, "sync_type": "sync"}
-        )
+        capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "sync"})
        return response.json()

    @api_error_handler
@@ -379,13 +357,9 @@ class MemoryClient:
            APIError: If the API request fails.
        """
        params = self._prepare_params()
-        response = self.client.get(
-            f"/v1/memories/{memory_id}/history/", params=params
-        )
+        response = self.client.get(f"/v1/memories/{memory_id}/history/", params=params)
        response.raise_for_status()
-        capture_client_event(
-            "client.history", self, {"memory_id": memory_id, "sync_type": "sync"}
-        )
+        capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "sync"})
        return response.json()

    @api_error_handler
@@ -432,10 +406,7 @@ class MemoryClient:
        else:
            entities = self.users()

        # Filter entities based on provided IDs using list comprehension
-        to_delete = [
-            {"type": entity["type"], "name": entity["name"]}
-            for entity in entities["results"]
-        ]
+        to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]

        params = self._prepare_params()
@@ -444,9 +415,7 @@ class MemoryClient:
        # Delete entities and check response immediately
        for entity in to_delete:
-            response = self.client.delete(
-                f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
-            )
+            response = self.client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
            response.raise_for_status()
            capture_client_event(
@@ -484,9 +453,7 @@ class MemoryClient:
        self.delete_users()
        capture_client_event("client.reset", self, {"sync_type": "sync"})
-        return {
-            "message": "Client reset successful. All users and memories deleted."
-        }
+        return {"message": "Client reset successful. All users and memories deleted."}

    @api_error_handler
    def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:
@@ -507,9 +474,7 @@ class MemoryClient:
response = self.client.put("/v1/batch/", json={"memories": memories}) response = self.client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.batch_update", self, {"sync_type": "sync"})
"client.batch_update", self, {"sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -527,14 +492,10 @@ class MemoryClient:
        Raises:
            APIError: If the API request fails.
        """
-        response = self.client.request(
-            "DELETE", "/v1/batch/", json={"memories": memories}
-        )
+        response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
        response.raise_for_status()
-        capture_client_event(
-            "client.batch_delete", self, {"sync_type": "sync"}
-        )
+        capture_client_event("client.batch_delete", self, {"sync_type": "sync"})
        return response.json()

    @api_error_handler
@@ -574,9 +535,7 @@ class MemoryClient:
        Returns:
            Dict containing the exported data
        """
-        response = self.client.post(
-            "/v1/exports/get/", json=self._prepare_params(kwargs)
-        )
+        response = self.client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
        response.raise_for_status()
        capture_client_event(
            "client.get_memory_export",
@@ -586,9 +545,7 @@ class MemoryClient:
        return response.json()

    @api_error_handler
-    def get_summary(
-        self, filters: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, Any]:
+    def get_summary(self, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Get the summary of a memory export.

        Args:
@@ -598,17 +555,13 @@ class MemoryClient:
            Dict containing the export status and summary data
        """
-        response = self.client.post(
-            "/v1/summary/", json=self._prepare_params({"filters": filters})
-        )
+        response = self.client.post("/v1/summary/", json=self._prepare_params({"filters": filters}))
        response.raise_for_status()
        capture_client_event("client.get_summary", self, {"sync_type": "sync"})
        return response.json()

    @api_error_handler
-    def get_project(
-        self, fields: Optional[List[str]] = None
-    ) -> Dict[str, Any]:
+    def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
        """Get instructions or categories for the current project.

        Args:
@@ -622,10 +575,7 @@ class MemoryClient:
            ValueError: If org_id or project_id are not set.
        """
        if not (self.org_id and self.project_id):
-            raise ValueError(
-                "org_id and project_id must be set to access instructions or "
-                "categories"
-            )
+            raise ValueError("org_id and project_id must be set to access instructions or categories")

        params = self._prepare_params({"fields": fields})
        response = self.client.get(
@@ -666,10 +616,7 @@ class MemoryClient:
            ValueError: If org_id or project_id are not set.
        """
        if not (self.org_id and self.project_id):
-            raise ValueError(
-                "org_id and project_id must be set to update instructions or "
-                "categories"
-            )
+            raise ValueError("org_id and project_id must be set to update instructions or categories")

        if (
            custom_instructions is None
@@ -826,10 +773,7 @@ class MemoryClient:
        feedback = feedback.upper() if feedback else None
        if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
-            raise ValueError(
-                f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} '
-                "or None"
-            )
+            raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")

        data = {
            "memory_id": memory_id,
@@ -839,14 +783,10 @@ class MemoryClient:
response = self.client.post("/v1/feedback/", json=data) response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
"client.feedback", self, data, {"sync_type": "sync"}
)
return response.json() return response.json()
def _prepare_payload( def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepare the payload for API requests. """Prepare the payload for API requests.
Args: Args:
@@ -862,9 +802,7 @@ class MemoryClient:
        payload.update({k: v for k, v in kwargs.items() if v is not None})
        return payload

-    def _prepare_params(
-        self, kwargs: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, Any]:
+    def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Prepare query parameters for API requests.

        Args:
@@ -929,9 +867,7 @@ class AsyncMemoryClient:
        self.user_id = get_user_id()

        if not self.api_key:
-            raise ValueError(
-                "Mem0 API Key not provided. Please provide an API Key."
-            )
+            raise ValueError("Mem0 API Key not provided. Please provide an API Key.")

        # Create MD5 hash of API key for user_id
        self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
@@ -989,9 +925,7 @@ class AsyncMemoryClient:
            error_message = str(e)
            raise ValueError(f"Error: {error_message}")

-    def _prepare_payload(
-        self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Prepare the payload for API requests.

        Args:
@@ -1007,9 +941,7 @@ class AsyncMemoryClient:
        payload.update({k: v for k, v in kwargs.items() if v is not None})
        return payload

-    def _prepare_params(
-        self, kwargs: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, Any]:
+    def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Prepare query parameters for API requests.

        Args:
@@ -1041,9 +973,7 @@ class AsyncMemoryClient:
        await self.async_client.aclose()

    @api_error_handler
-    async def add(
-        self, messages: List[Dict[str, str]], **kwargs
-    ) -> Dict[str, Any]:
+    async def add(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        kwargs = self._prepare_params(kwargs)
        if kwargs.get("output_format") != "v1.1":
            kwargs["output_format"] = "v1.1"
@@ -1062,45 +992,31 @@ class AsyncMemoryClient:
        response.raise_for_status()
        if "metadata" in kwargs:
            del kwargs["metadata"]
-        capture_client_event(
-            "client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}
-        )
+        capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
        return response.json()

    @api_error_handler
    async def get(self, memory_id: str) -> Dict[str, Any]:
        params = self._prepare_params()
-        response = await self.async_client.get(
-            f"/v1/memories/{memory_id}/", params=params
-        )
+        response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
        response.raise_for_status()
-        capture_client_event(
-            "client.get", self, {"memory_id": memory_id, "sync_type": "async"}
-        )
+        capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
        return response.json()

    @api_error_handler
-    async def get_all(
-        self, version: str = "v1", **kwargs
-    ) -> List[Dict[str, Any]]:
+    async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
        params = self._prepare_params(kwargs)
        if version == "v1":
-            response = await self.async_client.get(
-                f"/{version}/memories/", params=params
-            )
+            response = await self.async_client.get(f"/{version}/memories/", params=params)
        elif version == "v2":
            if "page" in params and "page_size" in params:
                query_params = {
                    "page": params.pop("page"),
                    "page_size": params.pop("page_size"),
                }
-                response = await self.async_client.post(
-                    f"/{version}/memories/", json=params, params=query_params
-                )
+                response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
            else:
-                response = await self.async_client.post(
-                    f"/{version}/memories/", json=params
-                )
+                response = await self.async_client.post(f"/{version}/memories/", json=params)
        response.raise_for_status()
        if "metadata" in kwargs:
            del kwargs["metadata"]
@@ -1116,14 +1032,10 @@ class AsyncMemoryClient:
        return response.json()

    @api_error_handler
-    async def search(
-        self, query: str, version: str = "v1", **kwargs
-    ) -> List[Dict[str, Any]]:
+    async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
        payload = {"query": query}
        payload.update(self._prepare_params(kwargs))
-        response = await self.async_client.post(
-            f"/{version}/memories/search/", json=payload
-        )
+        response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
        response.raise_for_status()
        if "metadata" in kwargs:
            del kwargs["metadata"]
@@ -1139,7 +1051,9 @@ class AsyncMemoryClient:
        return response.json()

    @api_error_handler
-    async def update(self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+    async def update(
+        self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
        """
        Update a memory by ID.

        Args:
@@ -1265,10 +1179,7 @@ class AsyncMemoryClient:
        else:
            entities = await self.users()

        # Filter entities based on provided IDs using list comprehension
-        to_delete = [
-            {"type": entity["type"], "name": entity["name"]}
-            for entity in entities["results"]
-        ]
+        to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]

        params = self._prepare_params()
@@ -1277,9 +1188,7 @@ class AsyncMemoryClient:
        # Delete entities and check response immediately
        for entity in to_delete:
-            response = await self.async_client.delete(
-                f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
-            )
+            response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
            response.raise_for_status()
            capture_client_event(
@@ -1335,9 +1244,7 @@ class AsyncMemoryClient:
        response = await self.async_client.put("/v1/batch/", json={"memories": memories})
        response.raise_for_status()
-        capture_client_event(
-            "client.batch_update", self, {"sync_type": "async"}
-        )
+        capture_client_event("client.batch_update", self, {"sync_type": "async"})
        return response.json()

    @api_error_handler
@@ -1355,14 +1262,10 @@ class AsyncMemoryClient:
        Raises:
            APIError: If the API request fails.
        """
-        response = await self.async_client.request(
-            "DELETE", "/v1/batch/", json={"memories": memories}
-        )
+        response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
        response.raise_for_status()
-        capture_client_event(
-            "client.batch_delete", self, {"sync_type": "async"}
-        )
+        capture_client_event("client.batch_delete", self, {"sync_type": "async"})
        return response.json()

    @api_error_handler
@@ -1614,7 +1517,7 @@ class AsyncMemoryClient:
        feedback = feedback.upper() if feedback else None
        if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
-            raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
+            raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")

        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}

View File

@@ -1,4 +1,3 @@
-from enum import Enum
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator

View File

@@ -11,7 +11,7 @@ class MongoDBConfig(BaseModel):
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
    mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")

-    @model_validator(mode='before')
+    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())

View File

@@ -36,6 +36,6 @@ class OpenSearchConfig(BaseModel):
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
-                f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
            )
        return values
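The joined f-string is behavior-neutral; the validator still rejects unknown keys before field validation runs. A standalone sketch of the same pattern, assuming pydantic v2 (the `StrictConfig` model and its fields are illustrative, not part of mem0):

```python
from typing import Any, Dict

from pydantic import BaseModel, model_validator


class StrictConfig(BaseModel):
    """Standalone illustration of the extra-field check used by OpenSearchConfig."""

    host: str = "localhost"
    port: int = 9200

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
            )
        return values


StrictConfig(host="localhost")      # accepted
# StrictConfig(hosst="localhost")   # raises: Extra fields not allowed: hosst. Allowed fields: host, port
```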

View File

@@ -36,4 +36,4 @@ class GoogleGenAIEmbedding(EmbeddingBase):
        # Call the embed_content method with the correct parameters
        response = self.client.models.embed_content(model=self.config.model, contents=text, config=config)

        return response.embeddings[0].values

View File

@@ -92,10 +92,12 @@ class AWSBedrockLLM(LLMBase):
if response["output"]["message"]["content"]: if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]: for item in response["output"]["message"]["content"]:
if "toolUse" in item: if "toolUse" in item:
processed_response["tool_calls"].append({ processed_response["tool_calls"].append(
"name": item["toolUse"]["name"], {
"arguments": item["toolUse"]["input"], "name": item["toolUse"]["name"],
}) "arguments": item["toolUse"]["input"],
}
)
return processed_response return processed_response

View File

@@ -165,7 +165,6 @@ class GeminiLLM(LLMBase):
        if system_instruction:
            config_params["system_instruction"] = system_instruction

        if response_format is not None and response_format["type"] == "json_object":
            config_params["response_mime_type"] = "application/json"
            if "schema" in response_format:
@@ -175,7 +174,6 @@ class GeminiLLM(LLMBase):
            formatted_tools = self._reformat_tools(tools)
            config_params["tools"] = formatted_tools

        if tool_choice:
            if tool_choice == "auto":
                mode = types.FunctionCallingConfigMode.AUTO

View File

@@ -18,7 +18,7 @@ class SarvamLLM(LLMBase):
        if not self.api_key:
            raise ValueError(
-                "Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
+                "Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
            )

        # Set base URL - use config value or environment or default

View File

@@ -7,7 +7,6 @@ from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
-from openai import OpenAI


class VllmLLM(LLMBase):
@@ -41,10 +40,12 @@ class VllmLLM(LLMBase):
        if response.choices[0].message.tool_calls:
            for tool_call in response.choices[0].message.tool_calls:
-                processed_response["tool_calls"].append({
-                    "name": tool_call.function.name,
-                    "arguments": json.loads(extract_json(tool_call.function.arguments)),
-                })
+                processed_response["tool_calls"].append(
+                    {
+                        "name": tool_call.function.name,
+                        "arguments": json.loads(extract_json(tool_call.function.arguments)),
+                    }
+                )

            return processed_response
        else:

View File

@@ -136,7 +136,6 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]} params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params) self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100): def get_all(self, filters, limit=100):
""" """
Retrieves all nodes and relationships from the graph database based on optional filtering criteria. Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -176,7 +175,6 @@ class MemoryGraph:
        return final_results

    def _retrieve_nodes_from_data(self, data, filters):
        """Extracts all the entities mentioned in the query."""
        _tools = [EXTRACT_ENTITIES_TOOL]
@@ -213,7 +211,7 @@ class MemoryGraph:
    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
        """Establish relations among the extracted nodes."""
        # Compose user identification string for prompt
        user_identity = f"user_id: {filters['user_id']}"
        if filters.get("agent_id"):
            user_identity += f", agent_id: {filters['agent_id']}"
@@ -221,9 +219,7 @@ class MemoryGraph:
        if self.config.graph_store.custom_prompt:
            system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
            # Add the custom prompt line if configured
-            system_content = system_content.replace(
-                "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
-            )
+            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": data},
@@ -336,7 +332,7 @@ class MemoryGraph:
user_id = filters["user_id"] user_id = filters["user_id"]
agent_id = filters.get("agent_id", None) agent_id = filters.get("agent_id", None)
results = [] results = []
for item in to_be_deleted: for item in to_be_deleted:
source = item["source"] source = item["source"]
destination = item["destination"] destination = item["destination"]
@@ -349,7 +345,7 @@ class MemoryGraph:
"dest_name": destination, "dest_name": destination,
"user_id": user_id, "user_id": user_id,
} }
if agent_id: if agent_id:
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id" agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
params["agent_id"] = agent_id params["agent_id"] = agent_id
@@ -366,10 +362,10 @@ class MemoryGraph:
            m.name AS target,
            type(r) AS relationship
            """

            result = self.graph.query(cypher, params=params)
            results.append(result)

        return results

    def _add_entities(self, to_be_added, filters, entity_type_map):
@@ -430,7 +426,7 @@ class MemoryGraph:
            r.mentions = coalesce(r.mentions, 0) + 1
            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
            """

            params = {
                "source_id": source_node_search_result[0]["elementId(source_candidate)"],
                "destination_name": destination,
@@ -592,7 +588,6 @@ class MemoryGraph:
        result = self.graph.query(cypher, params=params)
        return result

    def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
        agent_filter = ""
        if filters.get("agent_id"):

View File

@@ -338,7 +338,7 @@ class Memory(MemoryBase):
        except Exception as e:
            logger.error(f"Error in new_retrieved_facts: {e}")
            new_retrieved_facts = []

        if not new_retrieved_facts:
            logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
@@ -1166,7 +1166,7 @@ class AsyncMemory(MemoryBase):
        except Exception as e:
            logger.error(f"Error in new_retrieved_facts: {e}")
            new_retrieved_facts = []

        if not new_retrieved_facts:
            logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")

View File

@@ -162,7 +162,7 @@ class MemoryGraph:
        LIMIT $limit
        """
        params = {"user_id": filters["user_id"], "limit": limit}

        results = self.graph.query(query, params=params)
        final_results = []
@@ -318,7 +318,7 @@ class MemoryGraph:
"user_id": filters["user_id"], "user_id": filters["user_id"],
"limit": limit, "limit": limit,
} }
ans = self.graph.query(cypher_query, params=params) ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans) result_relations.extend(ans)
@@ -356,7 +356,7 @@ class MemoryGraph:
user_id = filters["user_id"] user_id = filters["user_id"]
agent_id = filters.get("agent_id", None) agent_id = filters.get("agent_id", None)
results = [] results = []
for item in to_be_deleted: for item in to_be_deleted:
source = item["source"] source = item["source"]
destination = item["destination"] destination = item["destination"]
@@ -369,7 +369,7 @@ class MemoryGraph:
"dest_name": destination, "dest_name": destination,
"user_id": user_id, "user_id": user_id,
} }
if agent_id: if agent_id:
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id" agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
params["agent_id"] = agent_id params["agent_id"] = agent_id
@@ -386,10 +386,10 @@ class MemoryGraph:
            m.name AS target,
            type(r) AS relationship
            """

            result = self.graph.query(cypher, params=params)
            results.append(result)

        return results

    # added Entity label to all nodes for vector search to work
@@ -398,7 +398,7 @@ class MemoryGraph:
user_id = filters["user_id"] user_id = filters["user_id"]
agent_id = filters.get("agent_id", None) agent_id = filters.get("agent_id", None)
results = [] results = []
for item in to_be_added: for item in to_be_added:
# entities # entities
source = item["source"] source = item["source"]
@@ -421,7 +421,7 @@ class MemoryGraph:
agent_id_clause = "" agent_id_clause = ""
if agent_id: if agent_id:
agent_id_clause = ", agent_id: $agent_id" agent_id_clause = ", agent_id: $agent_id"
# TODO: Create a cypher query and common params for all the cases # TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result: if not destination_node_search_result and source_node_search_result:
cypher = f""" cypher = f"""
@@ -446,7 +446,7 @@ class MemoryGraph:
                }
                if agent_id:
                    params["agent_id"] = agent_id

            elif destination_node_search_result and not source_node_search_result:
                cypher = f"""
                MATCH (destination:Entity)
@@ -470,7 +470,7 @@ class MemoryGraph:
                }
                if agent_id:
                    params["agent_id"] = agent_id

            elif source_node_search_result and destination_node_search_result:
                cypher = f"""
                MATCH (source:Entity)
@@ -490,7 +490,7 @@ class MemoryGraph:
                }
                if agent_id:
                    params["agent_id"] = agent_id

            else:
                cypher = f"""
                MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
@@ -512,7 +512,7 @@ class MemoryGraph:
} }
if agent_id: if agent_id:
params["agent_id"] = agent_id params["agent_id"] = agent_id
result = self.graph.query(cypher, params=params) result = self.graph.query(cypher, params=params)
results.append(result) results.append(result)
return results return results
@@ -528,7 +528,7 @@ class MemoryGraph:
"""Search for source nodes with similar embeddings.""" """Search for source nodes with similar embeddings."""
user_id = filters["user_id"] user_id = filters["user_id"]
agent_id = filters.get("agent_id", None) agent_id = filters.get("agent_id", None)
if agent_id: if agent_id:
cypher = """ cypher = """
CALL vector_search.search("memzero", 1, $source_embedding) CALL vector_search.search("memzero", 1, $source_embedding)
@@ -567,7 +567,7 @@ class MemoryGraph:
"""Search for destination nodes with similar embeddings.""" """Search for destination nodes with similar embeddings."""
user_id = filters["user_id"] user_id = filters["user_id"]
agent_id = filters.get("agent_id", None) agent_id = filters.get("agent_id", None)
if agent_id: if agent_id:
cypher = """ cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding) CALL vector_search.search("memzero", 1, $destination_embedding)

View File

@@ -1,5 +1,5 @@
import logging
-from typing import List, Optional, Dict, Any, Callable
+from typing import List, Optional, Dict, Any
from pydantic import BaseModel
@@ -26,13 +26,7 @@ class MongoDB(VectorStoreBase):
VECTOR_TYPE = "knnVector"
SIMILARITY_METRIC = "cosine"
-def __init__(
-self,
-db_name: str,
-collection_name: str,
-embedding_model_dims: int,
-mongo_uri: str
-):
+def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
"""
Initialize the MongoDB vector store with vector search capabilities.
@@ -46,9 +40,7 @@ class MongoDB(VectorStoreBase):
self.embedding_model_dims = embedding_model_dims
self.db_name = db_name
-self.client = MongoClient(
-mongo_uri
-)
+self.client = MongoClient(mongo_uri)
self.db = self.client[db_name]
self.collection = self.create_col()
@@ -119,7 +111,9 @@ class MongoDB(VectorStoreBase):
except PyMongoError as e:
logger.error(f"Error inserting data: {e}")
-def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
+def search(
+self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
+) -> List[OutputData]:
"""
Search for similar vectors using the vector search index.
@@ -285,7 +279,7 @@ class MongoDB(VectorStoreBase):
except PyMongoError as e:
logger.error(f"Error listing documents: {e}")
return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")

View File

@@ -88,7 +88,7 @@ class OpenSearchDB(VectorStoreBase):
self.client.indices.create(index=name, body=index_settings)
# Wait for index to be ready
max_retries = 180 # 3 minutes timeout
retry_count = 0
while retry_count < max_retries:
try:
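The hunk above only shows the retry bookkeeping (`max_retries = 180`, roughly three minutes at one check per second). Below is a hedged sketch of what such a readiness poll can look like with `opensearch-py`; the one-second sleep and the `indices.exists` probe are assumptions, not details confirmed by this diff.

```python
import time

from opensearchpy import OpenSearch


def wait_for_index(client: OpenSearch, name: str, max_retries: int = 180) -> None:
    """Poll until the index exists, giving up after roughly three minutes."""
    retry_count = 0
    while retry_count < max_retries:
        try:
            if client.indices.exists(index=name):
                return
        except Exception:
            pass  # tolerate transient connection errors while the node is still starting
        retry_count += 1
        time.sleep(1)  # 180 retries at ~1s each matches the "3 minutes timeout" comment
    raise TimeoutError(f"Index {name} was not ready after {max_retries} retries")
```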

poetry.lock (generated): 5797 changes; file diff suppressed because it is too large.

View File

@@ -30,11 +30,13 @@ vector_stores = [
"vecs>=0.4.0",
"chromadb>=0.4.24",
"weaviate-client>=4.4.0",
-"pinecone<7.0.0",
-"pinecone-text>=0.1.1",
+"pinecone<=7.3.0",
+"pinecone-text>=0.10.0",
"faiss-cpu>=1.7.4",
"upstash-vector>=0.1.0",
"azure-search-documents>=11.4.0b8",
+"pymongo>=4.13.2",
+"pymochow>=2.2.9",
]
llms = [
"groq>=0.3.0",
@@ -44,12 +46,11 @@ llms = [
"vertexai>=0.1.0",
"google-generativeai>=0.3.0",
"google-genai>=1.0.0",
]
extras = [
"boto3>=1.34.0",
"langchain-community>=0.0.0",
-"sentence-transformers>=2.2.2",
+"sentence-transformers>=5.0.0",
"elasticsearch>=8.0.0",
"opensearch-py>=2.0.0",
"langchain-memgraph>=0.1.0",

View File

@@ -1,4 +1,4 @@
-from unittest.mock import patch
+from unittest.mock import patch, ANY
import pytest
@@ -8,8 +8,10 @@ from mem0.embeddings.gemini import GoogleGenAIEmbedding
@pytest.fixture
def mock_genai():
-with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai:
-yield mock_genai
+with patch("mem0.embeddings.gemini.genai.Client") as mock_client_class:
+mock_client = mock_client_class.return_value
+mock_client.models.embed_content.return_value = None
+yield mock_client.models.embed_content
@pytest.fixture
@@ -18,7 +20,9 @@ def config():
def test_embed_query(mock_genai, config):
-mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]}
+mock_embedding_response = type('Response', (), {
+'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
+})()
mock_genai.return_value = mock_embedding_response
embedder = GoogleGenAIEmbedding(config)
@@ -27,10 +31,11 @@ def test_embed_query(mock_genai, config):
embedding = embedder.embed(text)
assert embedding == [0.1, 0.2, 0.3, 0.4]
-mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
+mock_genai.assert_called_once_with(model="test_model", contents="Hello, world!", config=ANY)
def test_embed_returns_empty_list_if_none(mock_genai, config):
-mock_genai.return_value = None
+mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()
embedder = GoogleGenAIEmbedding(config)
result = embedder.embed("test")
@@ -47,10 +52,10 @@ def test_embed_raises_on_error(mock_genai, config):
with pytest.raises(RuntimeError, match="Embedding failed"):
embedder.embed("some input")
def test_config_initialization(config):
embedder = GoogleGenAIEmbedding(config)
assert embedder.config.api_key == "dummy_api_key"
assert embedder.config.model == "test_model"
assert embedder.config.embedding_dims == 786
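For context, the mocks above mirror the call shape of the newer `google-genai` client, where embeddings come back as `response.embeddings[0].values`. A rough sketch under that assumption follows; the exact config the real `GoogleGenAIEmbedding` builds is not visible in this diff, so the `EmbedContentConfig` below is illustrative only, and the API key and model name are the placeholders used by the tests.

```python
from google import genai
from google.genai import types

# Sketch of the call the updated tests mock: client.models.embed_content(...)
# returning a response whose .embeddings[0].values holds the vector.
client = genai.Client(api_key="dummy_api_key")  # placeholder key
response = client.models.embed_content(
    model="test_model",  # placeholder model name from the test config
    contents="Hello, world!",
    config=types.EmbedContentConfig(output_dimensionality=786),  # assumed config
)
vector = response.embeddings[0].values if response and response.embeddings else []
```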

View File

@@ -9,7 +9,7 @@ from mem0.llms.gemini import GeminiLLM
@pytest.fixture
def mock_gemini_client():
-with patch("mem0.llms.gemini.genai") as mock_client_class:
+with patch("mem0.llms.gemini.genai.Client") as mock_client_class:
mock_client = Mock()
mock_client_class.return_value = mock_client
yield mock_client
@@ -24,43 +24,30 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
]
mock_part = Mock(text="I'm doing well, thank you for asking!")
-mock_embedding = Mock()
-mock_embedding.values = [0.1, 0.2, 0.3]
-mock_response = Mock()
-mock_response.candidates = [Mock()]
-mock_response.candidates[0].content.parts = [Mock()]
-mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
+mock_content = Mock(parts=[mock_part])
+mock_candidate = Mock(content=mock_content)
+mock_response = Mock(candidates=[mock_candidate])
mock_gemini_client.models.generate_content.return_value = mock_response
-mock_content = Mock(parts=[mock_part])
-mock_message = Mock(content=mock_content)
-mock_response = Mock(candidates=[mock_message])
-mock_gemini_client.generate_content.return_value = mock_response
response = llm.generate_response(messages)
-mock_gemini_client.generate_content.assert_called_once_with(
-contents=[
-{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
-{"parts": "Hello, how are you?", "role": "user"},
-],
-config=types.GenerateContentConfig(
-temperature=0.7,
-max_output_tokens=100,
-top_p=1.0,
-tools=None,
-tool_config=types.ToolConfig(
-function_calling_config=types.FunctionCallingConfig(
-allowed_function_names=None,
-mode="auto"
-)
-)
-)
-)
-assert response == "I'm doing well, thank you for asking!"
+# Check the actual call - system instruction is now in config
+mock_gemini_client.models.generate_content.assert_called_once()
+call_args = mock_gemini_client.models.generate_content.call_args
+# Verify model and contents
+assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
+assert len(call_args.kwargs['contents']) == 1 # Only user message
+# Verify config has system instruction
+config_arg = call_args.kwargs['config']
+assert config_arg.system_instruction == "You are a helpful assistant."
+assert config_arg.temperature == 0.7
+assert config_arg.max_output_tokens == 100
+assert config_arg.top_p == 1.0
+assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock):
@@ -89,64 +76,42 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}
-mock_part = Mock()
-mock_part.function_call = mock_tool_call
-mock_part.text = "I've added the memory for you."
+# Create mock parts with both text and function_call
+mock_text_part = Mock()
+mock_text_part.text = "I've added the memory for you."
+mock_text_part.function_call = None
+mock_func_part = Mock()
+mock_func_part.text = None
+mock_func_part.function_call = mock_tool_call
mock_content = Mock()
-mock_content.parts = [mock_part]
+mock_content.parts = [mock_text_part, mock_func_part]
-mock_message = Mock()
-mock_message.content = mock_content
-mock_response = Mock(candidates=[mock_message])
-mock_gemini_client.generate_content.return_value = mock_response
+mock_candidate = Mock()
+mock_candidate.content = mock_content
+mock_response = Mock(candidates=[mock_candidate])
+mock_gemini_client.models.generate_content.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
-mock_gemini_client.generate_content.assert_called_once_with(
-contents=[
-{
-"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
-"role": "user"
-},
-{
-"parts": "Add a new memory: Today is a sunny day.",
-"role": "user"
-},
-],
-config=types.GenerateContentConfig(
-temperature=0.7,
-max_output_tokens=100,
-top_p=1.0,
-tools=[
-types.Tool(
-function_declarations=[
-types.FunctionDeclaration(
-name="add_memory",
-description="Add a memory",
-parameters={
-"type": "object",
-"properties": {
-"data": {
-"type": "string",
-"description": "Data to add to memory"
-}
-},
-"required": ["data"]
-}
-)
-]
-)
-],
-tool_config=types.ToolConfig(
-function_calling_config=types.FunctionCallingConfig(
-allowed_function_names=None,
-mode="auto"
-)
-)
-)
-)
+# Check the actual call
+mock_gemini_client.models.generate_content.assert_called_once()
+call_args = mock_gemini_client.models.generate_content.call_args
+# Verify model and contents
+assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
+assert len(call_args.kwargs['contents']) == 1 # Only user message
+# Verify config has system instruction and tools
+config_arg = call_args.kwargs['config']
+assert config_arg.system_instruction == "You are a helpful assistant."
+assert config_arg.temperature == 0.7
+assert config_arg.max_output_tokens == 100
+assert config_arg.top_p == 1.0
+assert len(config_arg.tools) == 1
+assert config_arg.tool_config.function_calling_config.mode == types.FunctionCallingConfigMode.AUTO
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1

View File

@@ -43,6 +43,7 @@ def test_generate_response_without_tools(mock_lm_studio_client):
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_specifying_response_format(mock_lm_studio_client):
config = BaseLlmConfig(
model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
@@ -68,4 +69,4 @@ def test_generate_response_specifying_response_format(mock_lm_studio_client):
response_format={"type": "json_schema"},
)
assert response == "I'm doing well, thank you for asking!"

View File

@@ -71,7 +71,13 @@ def test_generate_response_with_tools(mock_vllm_client):
response = llm.generate_response(messages, tools=tools)
mock_vllm_client.chat.completions.create.assert_called_once_with(
-model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
+model="Qwen/Qwen2.5-32B-Instruct",
+messages=messages,
+temperature=0.7,
+max_tokens=100,
+top_p=1.0,
+tools=tools,
+tool_choice="auto",
)
assert response["content"] == "I've added the memory for you."

View File

@@ -253,10 +253,10 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
from mem0.embeddings.mock import MockEmbeddings
memory_custom_instance.llm.generate_response = Mock()
memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}'
memory_custom_instance.embedding_model = MockEmbeddings()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch(

View File

@@ -2,10 +2,9 @@ from unittest.mock import Mock, patch, PropertyMock
import pytest
-from mem0.vector_stores.baidu import BaiduDB, OutputData
+from mem0.vector_stores.baidu import BaiduDB
-from pymochow.model.enum import MetricType, TableState, ServerErrCode
+from pymochow.model.enum import TableState, ServerErrCode
-from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
-from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
+from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError

View File

@@ -1,7 +1,7 @@
import pytest
from unittest.mock import MagicMock, patch
from mem0.vector_stores.mongodb import MongoDB
from pymongo.operations import SearchIndexModel
@pytest.fixture
@patch("mem0.vector_stores.mongodb.MongoClient")
@@ -19,10 +19,11 @@ def mongo_vector_fixture(mock_mongo_client):
db_name="test_db",
collection_name="test_collection",
embedding_model_dims=1536,
-mongo_uri="mongodb://username:password@localhost:27017"
+mongo_uri="mongodb://username:password@localhost:27017",
)
return mongo_vector, mock_collection, mock_db
def test_initalize_create_col(mongo_vector_fixture):
mongo_vector, mock_collection, mock_db = mongo_vector_fixture
assert mongo_vector.collection_name == "test_collection"
@@ -49,12 +50,13 @@ def test_initalize_create_col(mongo_vector_fixture):
"dimensions": 1536,
"similarity": "cosine",
}
-}
+},
}
-}
+},
}
assert mongo_vector.collection == mock_collection
def test_insert(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
vectors = [[0.1] * 1536, [0.2] * 1536]
@@ -62,12 +64,13 @@ def test_insert(mongo_vector_fixture):
ids = ["id1", "id2"]
mongo_vector.insert(vectors, payloads, ids)
-expected_records=[
+expected_records = [
({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}),
-({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]})
+({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}),
]
mock_collection.insert_many.assert_called_once_with(expected_records)
def test_search(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
query_vector = [0.1] * 1536
@@ -79,25 +82,28 @@ def test_search(mongo_vector_fixture):
results = mongo_vector.search("query_str", query_vector, limit=2)
mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index")
-mock_collection.aggregate.assert_called_once_with([
-{
-"$vectorSearch": {
-"index": "test_collection_vector_index",
-"limit": 2,
-"numCandidates": 2,
-"queryVector": query_vector,
-"path": "embedding",
-},
-},
-{"$set": {"score": {"$meta": "vectorSearchScore"}}},
-{"$project": {"embedding": 0}},
-])
+mock_collection.aggregate.assert_called_once_with(
+[
+{
+"$vectorSearch": {
+"index": "test_collection_vector_index",
+"limit": 2,
+"numCandidates": 2,
+"queryVector": query_vector,
+"path": "embedding",
+},
+},
+{"$set": {"score": {"$meta": "vectorSearchScore"}}},
+{"$project": {"embedding": 0}},
+]
+)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].score == 0.9
assert results[1].id == "id2"
assert results[1].score == 0.8
def test_delete(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_delete_result = MagicMock()
@@ -107,6 +113,7 @@ def test_delete(mongo_vector_fixture):
mongo_vector.delete("id1")
mock_collection.delete_one.assert_called_with({"_id": "id1"})
def test_update(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_update_result = MagicMock()
@@ -122,6 +129,7 @@ def test_update(mongo_vector_fixture):
{"$set": {"embedding": vectorValue, "payload": payloadValue}},
)
def test_get(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}}
@@ -131,6 +139,7 @@ def test_get(mongo_vector_fixture):
assert result.id == "id1"
assert result.payload == {"key": "value1"}
def test_list_cols(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.list_collection_names.return_value = ["col1", "col2"]
@@ -138,12 +147,14 @@ def test_list_cols(mongo_vector_fixture):
collections = mongo_vector.list_cols()
assert collections == ["col1", "col2"]
def test_delete_col(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mongo_vector.delete_col()
mock_collection.drop.assert_called_once()
def test_col_info(mongo_vector_fixture):
mongo_vector, _, mock_db = mongo_vector_fixture
mock_db.command.return_value = {"count": 10, "size": 1024}
@@ -154,6 +165,7 @@ def test_col_info(mongo_vector_fixture):
assert info["count"] == 10
assert info["size"] == 1024
def test_list(mongo_vector_fixture):
mongo_vector, mock_collection, _ = mongo_vector_fixture
mock_cursor = MagicMock()

View File

@@ -88,7 +88,7 @@ def test_get_vector_found(pinecone_db):
# or a list of dictionaries, not a dictionary with an 'id' field
# Create a mock Vector object
-from pinecone.data.dataclasses.vector import Vector
+from pinecone import Vector
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
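The import moves from the SDK-internal `pinecone.data.dataclasses.vector` path to the package's public namespace, which lines up with the relaxed `pinecone<=7.3.0` pin in `pyproject.toml`. A tiny usage sketch with placeholder values:

```python
from pinecone import Vector

# Construct a Vector the same way the test above does; the id, values, and
# metadata here are placeholders, not anything prescribed by the SDK.
vec = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
print(vec.id, len(vec.values), vec.metadata)
```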