Fix CI issues related to missing dependency (#3096)
This commit is contained in:
7
.github/workflows/ci.yml
vendored
7
.github/workflows/ci.yml
vendored
@@ -58,11 +58,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install -e ".[test,graph,vector_stores,llms,extras]"
|
pip install -e ".[test,graph,vector_stores,llms,extras]"
|
||||||
|
pip install ruff
|
||||||
if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
|
if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
|
||||||
- name: Run Formatting
|
- name: Run Linting
|
||||||
run: |
|
run: make lint
|
||||||
mkdir -p .ruff_cache && chmod -R 777 .ruff_cache
|
|
||||||
hatch run format
|
|
||||||
- name: Run tests and generate coverage report
|
- name: Run tests and generate coverage report
|
||||||
run: make test
|
run: make test
|
||||||
|
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -13,7 +13,7 @@ install:
|
|||||||
install_all:
|
install_all:
|
||||||
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
|
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
|
||||||
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
|
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
|
||||||
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow
|
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow pymongo
|
||||||
|
|
||||||
# Format code with ruff
|
# Format code with ruff
|
||||||
format:
|
format:
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ def main():
|
|||||||
print("All categories accuracy:")
|
print("All categories accuracy:")
|
||||||
for cat, results in LLM_JUDGE.items():
|
for cat, results in LLM_JUDGE.items():
|
||||||
if results: # Only print if there are results for this category
|
if results: # Only print if there are results for this category
|
||||||
print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
|
print(f" Category {cat}: {np.mean(results):.4f} ({sum(results)}/{len(results)})")
|
||||||
print("------------------------------------------")
|
print("------------------------------------------")
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class RAGManager:
|
|||||||
def clean_chat_history(self, chat_history):
|
def clean_chat_history(self, chat_history):
|
||||||
cleaned_chat_history = ""
|
cleaned_chat_history = ""
|
||||||
for c in chat_history:
|
for c in chat_history:
|
||||||
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
|
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"
|
||||||
|
|
||||||
return cleaned_chat_history
|
return cleaned_chat_history
|
||||||
|
|
||||||
|
|||||||
@@ -1,271 +1,267 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "ApdaLD4Qi30H"
|
"id": "ApdaLD4Qi30H"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# Neo4j as Graph Memory"
|
"# Neo4j as Graph Memory"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "l7bi3i21i30I"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Prerequisites\n",
|
|
||||||
"\n",
|
|
||||||
"### 1. Install Mem0 with Graph Memory support\n",
|
|
||||||
"\n",
|
|
||||||
"To use Mem0 with Graph Memory support, install it using pip:\n",
|
|
||||||
"\n",
|
|
||||||
"```bash\n",
|
|
||||||
"pip install \"mem0ai[graph]\"\n",
|
|
||||||
"```\n",
|
|
||||||
"\n",
|
|
||||||
"This command installs Mem0 along with the necessary dependencies for graph functionality.\n",
|
|
||||||
"\n",
|
|
||||||
"### 2. Install Neo4j\n",
|
|
||||||
"\n",
|
|
||||||
"To utilize Neo4j as Graph Memory, run it with Docker:\n",
|
|
||||||
"\n",
|
|
||||||
"```bash\n",
|
|
||||||
"docker run \\\n",
|
|
||||||
" -p 7474:7474 -p 7687:7687 \\\n",
|
|
||||||
" -e NEO4J_AUTH=neo4j/password \\\n",
|
|
||||||
" neo4j:5\n",
|
|
||||||
"```\n",
|
|
||||||
"\n",
|
|
||||||
"This command starts Neo4j with default credentials (`neo4j` / `password`) and exposes both the HTTP (7474) and Bolt (7687) ports.\n",
|
|
||||||
"\n",
|
|
||||||
"You can access the Neo4j browser at [http://localhost:7474](http://localhost:7474).\n",
|
|
||||||
"\n",
|
|
||||||
"Additional information can be found in the [Neo4j documentation](https://neo4j.com/docs/).\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "DkeBdFEpi30I"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Configuration\n",
|
|
||||||
"\n",
|
|
||||||
"Do all the imports and configure OpenAI (enter your OpenAI API key):"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {
|
|
||||||
"id": "d99EfBpii30I"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from mem0 import Memory\n",
|
|
||||||
"\n",
|
|
||||||
"import os\n",
|
|
||||||
"\n",
|
|
||||||
"os.environ[\"OPENAI_API_KEY\"] = (\n",
|
|
||||||
" \"\"\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "QTucZJjIi30J"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"Set up configuration to use the embedder model and Neo4j as a graph store:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {
|
|
||||||
"id": "QSE0RFoSi30J"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"config = {\n",
|
|
||||||
" \"embedder\": {\n",
|
|
||||||
" \"provider\": \"openai\",\n",
|
|
||||||
" \"config\": {\"model\": \"text-embedding-3-large\", \"embedding_dims\": 1536},\n",
|
|
||||||
" },\n",
|
|
||||||
" \"graph_store\": {\n",
|
|
||||||
" \"provider\": \"neo4j\",\n",
|
|
||||||
" \"config\": {\n",
|
|
||||||
" \"url\": \"bolt://54.87.227.131:7687\",\n",
|
|
||||||
" \"username\": \"neo4j\",\n",
|
|
||||||
" \"password\": \"causes-bins-vines\",\n",
|
|
||||||
" },\n",
|
|
||||||
" },\n",
|
|
||||||
"}"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "OioTnv6xi30J"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Graph Memory initializiation\n",
|
|
||||||
"\n",
|
|
||||||
"Initialize Neo4j as a Graph Memory store:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {
|
|
||||||
"id": "fX-H9vgNi30J"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"m = Memory.from_config(config_dict=config)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "kr1fVMwEi30J"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Store memories\n",
|
|
||||||
"\n",
|
|
||||||
"Create memories:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {
|
|
||||||
"id": "sEfogqp_i30J"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"messages = [\n",
|
|
||||||
" {\n",
|
|
||||||
" \"role\": \"user\",\n",
|
|
||||||
" \"content\": \"I'm planning to watch a movie tonight. Any recommendations?\",\n",
|
|
||||||
" },\n",
|
|
||||||
" {\n",
|
|
||||||
" \"role\": \"assistant\",\n",
|
|
||||||
" \"content\": \"How about a thriller movies? They can be quite engaging.\",\n",
|
|
||||||
" },\n",
|
|
||||||
" {\n",
|
|
||||||
" \"role\": \"user\",\n",
|
|
||||||
" \"content\": \"I'm not a big fan of thriller movies but I love sci-fi movies.\",\n",
|
|
||||||
" },\n",
|
|
||||||
" {\n",
|
|
||||||
" \"role\": \"assistant\",\n",
|
|
||||||
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
|
|
||||||
" },\n",
|
|
||||||
"]\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "gtBHCyIgi30J"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"Store memories in Neo4j:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {
|
|
||||||
"id": "BMVGgZMFi30K"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Store inferred memories (default behavior)\n",
|
|
||||||
"result = m.add(\n",
|
|
||||||
" messages, user_id=\"alice\"\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "lQRptOywi30K"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
""
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "LBXW7Gv-i30K"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Search memories"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "UHFDeQBEi30K",
|
|
||||||
"outputId": "2c69de7d-a79a-48f6-e3c4-bd743067857c"
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Loves sci-fi movies 0.3153664287340898\n",
|
|
||||||
"Planning to watch a movie tonight 0.09683349296551162\n",
|
|
||||||
"Not a big fan of thriller movies 0.09468540071789466\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"for result in m.search(\"what does alice love?\", user_id=\"alice\")[\"results\"]:\n",
|
|
||||||
" print(result[\"memory\"], result[\"score\"])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {
|
|
||||||
"id": "2jXEIma9kK_Q"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"provenance": []
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": ".venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.13.2"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
{
|
||||||
"nbformat_minor": 0
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "l7bi3i21i30I"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Prerequisites\n",
|
||||||
|
"\n",
|
||||||
|
"### 1. Install Mem0 with Graph Memory support\n",
|
||||||
|
"\n",
|
||||||
|
"To use Mem0 with Graph Memory support, install it using pip:\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"pip install \"mem0ai[graph]\"\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"This command installs Mem0 along with the necessary dependencies for graph functionality.\n",
|
||||||
|
"\n",
|
||||||
|
"### 2. Install Neo4j\n",
|
||||||
|
"\n",
|
||||||
|
"To utilize Neo4j as Graph Memory, run it with Docker:\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"docker run \\\n",
|
||||||
|
" -p 7474:7474 -p 7687:7687 \\\n",
|
||||||
|
" -e NEO4J_AUTH=neo4j/password \\\n",
|
||||||
|
" neo4j:5\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"This command starts Neo4j with default credentials (`neo4j` / `password`) and exposes both the HTTP (7474) and Bolt (7687) ports.\n",
|
||||||
|
"\n",
|
||||||
|
"You can access the Neo4j browser at [http://localhost:7474](http://localhost:7474).\n",
|
||||||
|
"\n",
|
||||||
|
"Additional information can be found in the [Neo4j documentation](https://neo4j.com/docs/).\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "DkeBdFEpi30I"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Configuration\n",
|
||||||
|
"\n",
|
||||||
|
"Do all the imports and configure OpenAI (enter your OpenAI API key):"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"id": "d99EfBpii30I"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from mem0 import Memory\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"OPENAI_API_KEY\"] = \"\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "QTucZJjIi30J"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Set up configuration to use the embedder model and Neo4j as a graph store:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {
|
||||||
|
"id": "QSE0RFoSi30J"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"config = {\n",
|
||||||
|
" \"embedder\": {\n",
|
||||||
|
" \"provider\": \"openai\",\n",
|
||||||
|
" \"config\": {\"model\": \"text-embedding-3-large\", \"embedding_dims\": 1536},\n",
|
||||||
|
" },\n",
|
||||||
|
" \"graph_store\": {\n",
|
||||||
|
" \"provider\": \"neo4j\",\n",
|
||||||
|
" \"config\": {\n",
|
||||||
|
" \"url\": \"bolt://54.87.227.131:7687\",\n",
|
||||||
|
" \"username\": \"neo4j\",\n",
|
||||||
|
" \"password\": \"causes-bins-vines\",\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "OioTnv6xi30J"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Graph Memory initializiation\n",
|
||||||
|
"\n",
|
||||||
|
"Initialize Neo4j as a Graph Memory store:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {
|
||||||
|
"id": "fX-H9vgNi30J"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"m = Memory.from_config(config_dict=config)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "kr1fVMwEi30J"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Store memories\n",
|
||||||
|
"\n",
|
||||||
|
"Create memories:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {
|
||||||
|
"id": "sEfogqp_i30J"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"messages = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": \"I'm planning to watch a movie tonight. Any recommendations?\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"assistant\",\n",
|
||||||
|
" \"content\": \"How about a thriller movies? They can be quite engaging.\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": \"I'm not a big fan of thriller movies but I love sci-fi movies.\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"assistant\",\n",
|
||||||
|
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
|
||||||
|
" },\n",
|
||||||
|
"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "gtBHCyIgi30J"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Store memories in Neo4j:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {
|
||||||
|
"id": "BMVGgZMFi30K"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Store inferred memories (default behavior)\n",
|
||||||
|
"result = m.add(messages, user_id=\"alice\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "lQRptOywi30K"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "LBXW7Gv-i30K"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Search memories"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "UHFDeQBEi30K",
|
||||||
|
"outputId": "2c69de7d-a79a-48f6-e3c4-bd743067857c"
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Loves sci-fi movies 0.3153664287340898\n",
|
||||||
|
"Planning to watch a movie tonight 0.09683349296551162\n",
|
||||||
|
"Not a big fan of thriller movies 0.09468540071789466\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"for result in m.search(\"what does alice love?\", user_id=\"alice\")[\"results\"]:\n",
|
||||||
|
" print(result[\"memory\"], result[\"score\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {
|
||||||
|
"id": "2jXEIma9kK_Q"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,11 +45,7 @@ def get_food_recommendation(user_query: str, user_id):
|
|||||||
"""Get food recommendation with memory context"""
|
"""Get food recommendation with memory context"""
|
||||||
|
|
||||||
# Search memory for relevant food preferences
|
# Search memory for relevant food preferences
|
||||||
memories_result = memory_client.search(
|
memories_result = memory_client.search(query=user_query, user_id=user_id, limit=5)
|
||||||
query=user_query,
|
|
||||||
user_id=user_id,
|
|
||||||
limit=5
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add memory context to the message
|
# Add memory context to the message
|
||||||
memories = [f"- {result['memory']}" for result in memories_result]
|
memories = [f"- {result['memory']}" for result in memories_result]
|
||||||
@@ -71,6 +67,7 @@ def get_food_recommendation(user_query: str, user_id):
|
|||||||
# Save audio file
|
# Save audio file
|
||||||
if response.audio:
|
if response.audio:
|
||||||
import time
|
import time
|
||||||
|
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
filename = f"food_recommendation_{timestamp}.mp3"
|
filename = f"food_recommendation_{timestamp}.mp3"
|
||||||
write_audio_to_file(
|
write_audio_to_file(
|
||||||
@@ -118,7 +115,11 @@ def initialize_food_memory(user_id):
|
|||||||
# Initialize the memory for the user once in order for the agent to learn the user preference
|
# Initialize the memory for the user once in order for the agent to learn the user preference
|
||||||
initialize_food_memory(user_id=USER_ID)
|
initialize_food_memory(user_id=USER_ID)
|
||||||
|
|
||||||
print(get_food_recommendation("Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID))
|
print(
|
||||||
|
get_food_recommendation(
|
||||||
|
"Which type of restaurants should I go tonight for dinner and cuisines preferred?", user_id=USER_ID
|
||||||
|
)
|
||||||
|
)
|
||||||
# OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3
|
# OUTPUT: 🎵 Audio saved as food_recommendation_1750162610.mp3
|
||||||
# For dinner tonight, considering your love for healthy spic optionsy, you could try a nice Thai, Indian, or Mexican restaurant.
|
# For dinner tonight, considering your love for healthy spic optionsy, you could try a nice Thai, Indian, or Mexican restaurant.
|
||||||
# You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner!
|
# You might find dishes with quinoa, chickpeas, tofu, and fresh herbs delightful. Enjoy your dinner!
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from agents import Agent, Runner, function_tool, handoffs, enable_verbose_stdout_logging
|
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from mem0 import MemoryClient
|
from mem0 import MemoryClient
|
||||||
@@ -35,7 +35,7 @@ travel_agent = Agent(
|
|||||||
understand the user's travel preferences and history before making recommendations.
|
understand the user's travel preferences and history before making recommendations.
|
||||||
After providing your response, use store_conversation to save important details.""",
|
After providing your response, use store_conversation to save important details.""",
|
||||||
tools=[search_memory, save_memory],
|
tools=[search_memory, save_memory],
|
||||||
model="gpt-4o"
|
model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
|
||||||
health_agent = Agent(
|
health_agent = Agent(
|
||||||
@@ -44,7 +44,7 @@ health_agent = Agent(
|
|||||||
understand the user's health goals and dietary preferences.
|
understand the user's health goals and dietary preferences.
|
||||||
After providing advice, use store_conversation to save relevant information.""",
|
After providing advice, use store_conversation to save relevant information.""",
|
||||||
tools=[search_memory, save_memory],
|
tools=[search_memory, save_memory],
|
||||||
model="gpt-4o"
|
model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Triage agent with handoffs
|
# Triage agent with handoffs
|
||||||
@@ -55,7 +55,7 @@ triage_agent = Agent(
|
|||||||
For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor.
|
For health-related questions (fitness, diet, wellness, exercise), hand off to Health Advisor.
|
||||||
For general questions, you can handle them directly using available tools.""",
|
For general questions, you can handle them directly using available tools.""",
|
||||||
handoffs=[travel_agent, health_agent],
|
handoffs=[travel_agent, health_agent],
|
||||||
model="gpt-4o"
|
model="gpt-4o",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -74,10 +74,7 @@ def chat_with_handoffs(user_input: str, user_id: str) -> str:
|
|||||||
result = Runner.run_sync(triage_agent, user_input)
|
result = Runner.run_sync(triage_agent, user_input)
|
||||||
|
|
||||||
# Store the original conversation in memory
|
# Store the original conversation in memory
|
||||||
conversation = [
|
conversation = [{"role": "user", "content": user_input}, {"role": "assistant", "content": result.final_output}]
|
||||||
{"role": "user", "content": user_input},
|
|
||||||
{"role": "assistant", "content": result.final_output}
|
|
||||||
]
|
|
||||||
mem0.add(conversation, user_id=user_id)
|
mem0.add(conversation, user_id=user_id)
|
||||||
|
|
||||||
return result.final_output
|
return result.final_output
|
||||||
|
|||||||
@@ -34,96 +34,91 @@ config = {
|
|||||||
"api_key": "vllm-api-key",
|
"api_key": "vllm-api-key",
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
}
|
},
|
||||||
},
|
|
||||||
"embedder": {
|
|
||||||
"provider": "openai",
|
|
||||||
"config": {
|
|
||||||
"model": "text-embedding-3-small"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
|
||||||
"vector_store": {
|
"vector_store": {
|
||||||
"provider": "qdrant",
|
"provider": "qdrant",
|
||||||
"config": {
|
"config": {"collection_name": "vllm_memories", "host": "localhost", "port": 6333},
|
||||||
"collection_name": "vllm_memories",
|
},
|
||||||
"host": "localhost",
|
|
||||||
"port": 6333
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
"""
|
||||||
Demonstrate vLLM integration with mem0
|
Demonstrate vLLM integration with mem0
|
||||||
"""
|
"""
|
||||||
print("--> Initializing mem0 with vLLM...")
|
print("--> Initializing mem0 with vLLM...")
|
||||||
|
|
||||||
# Initialize memory with vLLM
|
# Initialize memory with vLLM
|
||||||
memory = Memory.from_config(config)
|
memory = Memory.from_config(config)
|
||||||
|
|
||||||
print("--> Memory initialized successfully!")
|
print("--> Memory initialized successfully!")
|
||||||
|
|
||||||
# Example conversations to store
|
# Example conversations to store
|
||||||
conversations = [
|
conversations = [
|
||||||
{
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "I love playing chess on weekends"},
|
{"role": "user", "content": "I love playing chess on weekends"},
|
||||||
{"role": "assistant", "content": "That's great! Chess is an excellent strategic game that helps improve critical thinking."}
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "That's great! Chess is an excellent strategic game that helps improve critical thinking.",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"user_id": "user_123"
|
"user_id": "user_123",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "I'm learning Python programming"},
|
{"role": "user", "content": "I'm learning Python programming"},
|
||||||
{"role": "assistant", "content": "Python is a fantastic language for beginners! What specific areas are you focusing on?"}
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Python is a fantastic language for beginners! What specific areas are you focusing on?",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"user_id": "user_123"
|
"user_id": "user_123",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "I prefer working late at night, I'm more productive then"},
|
{"role": "user", "content": "I prefer working late at night, I'm more productive then"},
|
||||||
{"role": "assistant", "content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you."}
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you.",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"user_id": "user_123"
|
"user_id": "user_123",
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
print("\n--> Adding memories using vLLM...")
|
print("\n--> Adding memories using vLLM...")
|
||||||
|
|
||||||
# Add memories - now powered by vLLM's high-performance inference
|
# Add memories - now powered by vLLM's high-performance inference
|
||||||
for i, conversation in enumerate(conversations, 1):
|
for i, conversation in enumerate(conversations, 1):
|
||||||
result = memory.add(
|
result = memory.add(messages=conversation["messages"], user_id=conversation["user_id"])
|
||||||
messages=conversation["messages"],
|
|
||||||
user_id=conversation["user_id"]
|
|
||||||
)
|
|
||||||
print(f"Memory {i} added: {result}")
|
print(f"Memory {i} added: {result}")
|
||||||
|
|
||||||
print("\n🔍 Searching memories...")
|
print("\n🔍 Searching memories...")
|
||||||
|
|
||||||
# Search memories - vLLM will process the search and memory operations
|
# Search memories - vLLM will process the search and memory operations
|
||||||
search_queries = [
|
search_queries = [
|
||||||
"What does the user like to do on weekends?",
|
"What does the user like to do on weekends?",
|
||||||
"What is the user learning?",
|
"What is the user learning?",
|
||||||
"When is the user most productive?"
|
"When is the user most productive?",
|
||||||
]
|
]
|
||||||
|
|
||||||
for query in search_queries:
|
for query in search_queries:
|
||||||
print(f"\nQuery: {query}")
|
print(f"\nQuery: {query}")
|
||||||
memories = memory.search(
|
memories = memory.search(query=query, user_id="user_123")
|
||||||
query=query,
|
|
||||||
user_id="user_123"
|
|
||||||
)
|
|
||||||
|
|
||||||
for memory_item in memories:
|
for memory_item in memories:
|
||||||
print(f" - {memory_item['memory']}")
|
print(f" - {memory_item['memory']}")
|
||||||
|
|
||||||
print("\n--> Getting all memories for user...")
|
print("\n--> Getting all memories for user...")
|
||||||
all_memories = memory.get_all(user_id="user_123")
|
all_memories = memory.get_all(user_id="user_123")
|
||||||
print(f"Total memories stored: {len(all_memories)}")
|
print(f"Total memories stored: {len(all_memories)}")
|
||||||
|
|
||||||
for memory_item in all_memories:
|
for memory_item in all_memories:
|
||||||
print(f" - {memory_item['memory']}")
|
print(f" - {memory_item['memory']}")
|
||||||
|
|
||||||
print("\n--> vLLM integration demo completed successfully!")
|
print("\n--> vLLM integration demo completed successfully!")
|
||||||
print("\nBenefits of using vLLM:")
|
print("\nBenefits of using vLLM:")
|
||||||
print(" -> 2.7x higher throughput compared to standard implementations")
|
print(" -> 2.7x higher throughput compared to standard implementations")
|
||||||
|
|||||||
@@ -89,9 +89,7 @@ class MemoryClient:
|
|||||||
self.user_id = get_user_id()
|
self.user_id = get_user_id()
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError(
|
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
|
||||||
"Mem0 API Key not provided. Please provide an API Key."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create MD5 hash of API key for user_id
|
# Create MD5 hash of API key for user_id
|
||||||
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
|
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
|
||||||
@@ -174,9 +172,7 @@ class MemoryClient:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
capture_client_event(
|
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"})
|
||||||
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -195,9 +191,7 @@ class MemoryClient:
|
|||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
|
response = self.client.get(f"/v1/memories/{memory_id}/", params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event(
|
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||||
"client.get", self, {"memory_id": memory_id, "sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -224,13 +218,9 @@ class MemoryClient:
|
|||||||
"page": params.pop("page"),
|
"page": params.pop("page"),
|
||||||
"page_size": params.pop("page_size"),
|
"page_size": params.pop("page_size"),
|
||||||
}
|
}
|
||||||
response = self.client.post(
|
response = self.client.post(f"/{version}/memories/", json=params, params=query_params)
|
||||||
f"/{version}/memories/", json=params, params=query_params
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
response = self.client.post(
|
response = self.client.post(f"/{version}/memories/", json=params)
|
||||||
f"/{version}/memories/", json=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
@@ -246,9 +236,7 @@ class MemoryClient:
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
def search(
|
def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||||
self, query: str, version: str = "v1", **kwargs
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""Search memories based on a query.
|
"""Search memories based on a query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -266,9 +254,7 @@ class MemoryClient:
|
|||||||
payload = {"query": query}
|
payload = {"query": query}
|
||||||
params = self._prepare_params(kwargs)
|
params = self._prepare_params(kwargs)
|
||||||
payload.update(params)
|
payload.update(params)
|
||||||
response = self.client.post(
|
response = self.client.post(f"/{version}/memories/search/", json=payload)
|
||||||
f"/{version}/memories/search/", json=payload
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
@@ -308,13 +294,9 @@ class MemoryClient:
|
|||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
payload["metadata"] = metadata
|
payload["metadata"] = metadata
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||||
"client.update", self, {"memory_id": memory_id, "sync_type": "sync"}
|
|
||||||
)
|
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
response = self.client.put(
|
response = self.client.put(f"/v1/memories/{memory_id}/", json=payload, params=params)
|
||||||
f"/v1/memories/{memory_id}/", json=payload, params=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@@ -332,13 +314,9 @@ class MemoryClient:
|
|||||||
APIError: If the API request fails.
|
APIError: If the API request fails.
|
||||||
"""
|
"""
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
response = self.client.delete(
|
response = self.client.delete(f"/v1/memories/{memory_id}/", params=params)
|
||||||
f"/v1/memories/{memory_id}/", params=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event(
|
capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||||
"client.delete", self, {"memory_id": memory_id, "sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -379,13 +357,9 @@ class MemoryClient:
|
|||||||
APIError: If the API request fails.
|
APIError: If the API request fails.
|
||||||
"""
|
"""
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
response = self.client.get(
|
response = self.client.get(f"/v1/memories/{memory_id}/history/", params=params)
|
||||||
f"/v1/memories/{memory_id}/history/", params=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event(
|
capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "sync"})
|
||||||
"client.history", self, {"memory_id": memory_id, "sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -432,10 +406,7 @@ class MemoryClient:
|
|||||||
else:
|
else:
|
||||||
entities = self.users()
|
entities = self.users()
|
||||||
# Filter entities based on provided IDs using list comprehension
|
# Filter entities based on provided IDs using list comprehension
|
||||||
to_delete = [
|
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
|
||||||
{"type": entity["type"], "name": entity["name"]}
|
|
||||||
for entity in entities["results"]
|
|
||||||
]
|
|
||||||
|
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
|
|
||||||
@@ -444,9 +415,7 @@ class MemoryClient:
|
|||||||
|
|
||||||
# Delete entities and check response immediately
|
# Delete entities and check response immediately
|
||||||
for entity in to_delete:
|
for entity in to_delete:
|
||||||
response = self.client.delete(
|
response = self.client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
|
||||||
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event(
|
||||||
@@ -484,9 +453,7 @@ class MemoryClient:
|
|||||||
self.delete_users()
|
self.delete_users()
|
||||||
|
|
||||||
capture_client_event("client.reset", self, {"sync_type": "sync"})
|
capture_client_event("client.reset", self, {"sync_type": "sync"})
|
||||||
return {
|
return {"message": "Client reset successful. All users and memories deleted."}
|
||||||
"message": "Client reset successful. All users and memories deleted."
|
|
||||||
}
|
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def batch_update(self, memories: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
@@ -507,9 +474,7 @@ class MemoryClient:
|
|||||||
response = self.client.put("/v1/batch/", json={"memories": memories})
|
response = self.client.put("/v1/batch/", json={"memories": memories})
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event("client.batch_update", self, {"sync_type": "sync"})
|
||||||
"client.batch_update", self, {"sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -527,14 +492,10 @@ class MemoryClient:
|
|||||||
Raises:
|
Raises:
|
||||||
APIError: If the API request fails.
|
APIError: If the API request fails.
|
||||||
"""
|
"""
|
||||||
response = self.client.request(
|
response = self.client.request("DELETE", "/v1/batch/", json={"memories": memories})
|
||||||
"DELETE", "/v1/batch/", json={"memories": memories}
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event("client.batch_delete", self, {"sync_type": "sync"})
|
||||||
"client.batch_delete", self, {"sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -574,9 +535,7 @@ class MemoryClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing the exported data
|
Dict containing the exported data
|
||||||
"""
|
"""
|
||||||
response = self.client.post(
|
response = self.client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
|
||||||
"/v1/exports/get/", json=self._prepare_params(kwargs)
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event(
|
capture_client_event(
|
||||||
"client.get_memory_export",
|
"client.get_memory_export",
|
||||||
@@ -586,9 +545,7 @@ class MemoryClient:
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
def get_summary(
|
def get_summary(self, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
self, filters: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Get the summary of a memory export.
|
"""Get the summary of a memory export.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -598,17 +555,13 @@ class MemoryClient:
|
|||||||
Dict containing the export status and summary data
|
Dict containing the export status and summary data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = self.client.post(
|
response = self.client.post("/v1/summary/", json=self._prepare_params({"filters": filters}))
|
||||||
"/v1/summary/", json=self._prepare_params({"filters": filters})
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event("client.get_summary", self, {"sync_type": "sync"})
|
capture_client_event("client.get_summary", self, {"sync_type": "sync"})
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
def get_project(
|
def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||||
self, fields: Optional[List[str]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Get instructions or categories for the current project.
|
"""Get instructions or categories for the current project.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -622,10 +575,7 @@ class MemoryClient:
|
|||||||
ValueError: If org_id or project_id are not set.
|
ValueError: If org_id or project_id are not set.
|
||||||
"""
|
"""
|
||||||
if not (self.org_id and self.project_id):
|
if not (self.org_id and self.project_id):
|
||||||
raise ValueError(
|
raise ValueError("org_id and project_id must be set to access instructions or categories")
|
||||||
"org_id and project_id must be set to access instructions or "
|
|
||||||
"categories"
|
|
||||||
)
|
|
||||||
|
|
||||||
params = self._prepare_params({"fields": fields})
|
params = self._prepare_params({"fields": fields})
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
@@ -666,10 +616,7 @@ class MemoryClient:
|
|||||||
ValueError: If org_id or project_id are not set.
|
ValueError: If org_id or project_id are not set.
|
||||||
"""
|
"""
|
||||||
if not (self.org_id and self.project_id):
|
if not (self.org_id and self.project_id):
|
||||||
raise ValueError(
|
raise ValueError("org_id and project_id must be set to update instructions or categories")
|
||||||
"org_id and project_id must be set to update instructions or "
|
|
||||||
"categories"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
custom_instructions is None
|
custom_instructions is None
|
||||||
@@ -826,10 +773,7 @@ class MemoryClient:
|
|||||||
|
|
||||||
feedback = feedback.upper() if feedback else None
|
feedback = feedback.upper() if feedback else None
|
||||||
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
|
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
|
||||||
raise ValueError(
|
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
|
||||||
f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} '
|
|
||||||
"or None"
|
|
||||||
)
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"memory_id": memory_id,
|
"memory_id": memory_id,
|
||||||
@@ -839,14 +783,10 @@ class MemoryClient:
|
|||||||
|
|
||||||
response = self.client.post("/v1/feedback/", json=data)
|
response = self.client.post("/v1/feedback/", json=data)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event(
|
capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
|
||||||
"client.feedback", self, data, {"sync_type": "sync"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def _prepare_payload(
|
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Prepare the payload for API requests.
|
"""Prepare the payload for API requests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -862,9 +802,7 @@ class MemoryClient:
|
|||||||
payload.update({k: v for k, v in kwargs.items() if v is not None})
|
payload.update({k: v for k, v in kwargs.items() if v is not None})
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _prepare_params(
|
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
self, kwargs: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Prepare query parameters for API requests.
|
"""Prepare query parameters for API requests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -929,9 +867,7 @@ class AsyncMemoryClient:
|
|||||||
self.user_id = get_user_id()
|
self.user_id = get_user_id()
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError(
|
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
|
||||||
"Mem0 API Key not provided. Please provide an API Key."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create MD5 hash of API key for user_id
|
# Create MD5 hash of API key for user_id
|
||||||
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
|
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
|
||||||
@@ -989,9 +925,7 @@ class AsyncMemoryClient:
|
|||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
raise ValueError(f"Error: {error_message}")
|
raise ValueError(f"Error: {error_message}")
|
||||||
|
|
||||||
def _prepare_payload(
|
def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Prepare the payload for API requests.
|
"""Prepare the payload for API requests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1007,9 +941,7 @@ class AsyncMemoryClient:
|
|||||||
payload.update({k: v for k, v in kwargs.items() if v is not None})
|
payload.update({k: v for k, v in kwargs.items() if v is not None})
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _prepare_params(
|
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
self, kwargs: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Prepare query parameters for API requests.
|
"""Prepare query parameters for API requests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1041,9 +973,7 @@ class AsyncMemoryClient:
|
|||||||
await self.async_client.aclose()
|
await self.async_client.aclose()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
async def add(
|
async def add(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
|
||||||
self, messages: List[Dict[str, str]], **kwargs
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
kwargs = self._prepare_params(kwargs)
|
kwargs = self._prepare_params(kwargs)
|
||||||
if kwargs.get("output_format") != "v1.1":
|
if kwargs.get("output_format") != "v1.1":
|
||||||
kwargs["output_format"] = "v1.1"
|
kwargs["output_format"] = "v1.1"
|
||||||
@@ -1062,45 +992,31 @@ class AsyncMemoryClient:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
capture_client_event(
|
capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
|
||||||
"client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
async def get(self, memory_id: str) -> Dict[str, Any]:
|
async def get(self, memory_id: str) -> Dict[str, Any]:
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
response = await self.async_client.get(
|
response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
|
||||||
f"/v1/memories/{memory_id}/", params=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event(
|
capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
|
||||||
"client.get", self, {"memory_id": memory_id, "sync_type": "async"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
async def get_all(
|
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||||
self, version: str = "v1", **kwargs
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
params = self._prepare_params(kwargs)
|
params = self._prepare_params(kwargs)
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
response = await self.async_client.get(
|
response = await self.async_client.get(f"/{version}/memories/", params=params)
|
||||||
f"/{version}/memories/", params=params
|
|
||||||
)
|
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
if "page" in params and "page_size" in params:
|
if "page" in params and "page_size" in params:
|
||||||
query_params = {
|
query_params = {
|
||||||
"page": params.pop("page"),
|
"page": params.pop("page"),
|
||||||
"page_size": params.pop("page_size"),
|
"page_size": params.pop("page_size"),
|
||||||
}
|
}
|
||||||
response = await self.async_client.post(
|
response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
|
||||||
f"/{version}/memories/", json=params, params=query_params
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
response = await self.async_client.post(
|
response = await self.async_client.post(f"/{version}/memories/", json=params)
|
||||||
f"/{version}/memories/", json=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
@@ -1116,14 +1032,10 @@ class AsyncMemoryClient:
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
async def search(
|
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||||
self, query: str, version: str = "v1", **kwargs
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
payload = {"query": query}
|
payload = {"query": query}
|
||||||
payload.update(self._prepare_params(kwargs))
|
payload.update(self._prepare_params(kwargs))
|
||||||
response = await self.async_client.post(
|
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
|
||||||
f"/{version}/memories/search/", json=payload
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
@@ -1139,7 +1051,9 @@ class AsyncMemoryClient:
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
async def update(self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
async def update(
|
||||||
|
self, memory_id: str, text: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Update a memory by ID.
|
Update a memory by ID.
|
||||||
Args:
|
Args:
|
||||||
@@ -1265,10 +1179,7 @@ class AsyncMemoryClient:
|
|||||||
else:
|
else:
|
||||||
entities = await self.users()
|
entities = await self.users()
|
||||||
# Filter entities based on provided IDs using list comprehension
|
# Filter entities based on provided IDs using list comprehension
|
||||||
to_delete = [
|
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
|
||||||
{"type": entity["type"], "name": entity["name"]}
|
|
||||||
for entity in entities["results"]
|
|
||||||
]
|
|
||||||
|
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
|
|
||||||
@@ -1277,9 +1188,7 @@ class AsyncMemoryClient:
|
|||||||
|
|
||||||
# Delete entities and check response immediately
|
# Delete entities and check response immediately
|
||||||
for entity in to_delete:
|
for entity in to_delete:
|
||||||
response = await self.async_client.delete(
|
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
|
||||||
f"/v2/entities/{entity['type']}/{entity['name']}/", params=params
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event(
|
||||||
@@ -1335,9 +1244,7 @@ class AsyncMemoryClient:
|
|||||||
response = await self.async_client.put("/v1/batch/", json={"memories": memories})
|
response = await self.async_client.put("/v1/batch/", json={"memories": memories})
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event("client.batch_update", self, {"sync_type": "async"})
|
||||||
"client.batch_update", self, {"sync_type": "async"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -1355,14 +1262,10 @@ class AsyncMemoryClient:
|
|||||||
Raises:
|
Raises:
|
||||||
APIError: If the API request fails.
|
APIError: If the API request fails.
|
||||||
"""
|
"""
|
||||||
response = await self.async_client.request(
|
response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
|
||||||
"DELETE", "/v1/batch/", json={"memories": memories}
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event("client.batch_delete", self, {"sync_type": "async"})
|
||||||
"client.batch_delete", self, {"sync_type": "async"}
|
|
||||||
)
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -1614,7 +1517,7 @@ class AsyncMemoryClient:
|
|||||||
|
|
||||||
feedback = feedback.upper() if feedback else None
|
feedback = feedback.upper() if feedback else None
|
||||||
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
|
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
|
||||||
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
|
raise ValueError(f"feedback must be one of {', '.join(VALID_FEEDBACK_VALUES)} or None")
|
||||||
|
|
||||||
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
|
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class MongoDBConfig(BaseModel):
|
|||||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
|
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
|
||||||
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
|
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
|
||||||
|
|
||||||
@model_validator(mode='before')
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
allowed_fields = set(cls.model_fields.keys())
|
allowed_fields = set(cls.model_fields.keys())
|
||||||
|
|||||||
@@ -36,6 +36,6 @@ class OpenSearchConfig(BaseModel):
|
|||||||
extra_fields = input_fields - allowed_fields
|
extra_fields = input_fields - allowed_fields
|
||||||
if extra_fields:
|
if extra_fields:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
|
f"Extra fields not allowed: {', '.join(extra_fields)}. Allowed fields: {', '.join(allowed_fields)}"
|
||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|||||||
@@ -36,4 +36,4 @@ class GoogleGenAIEmbedding(EmbeddingBase):
|
|||||||
# Call the embed_content method with the correct parameters
|
# Call the embed_content method with the correct parameters
|
||||||
response = self.client.models.embed_content(model=self.config.model, contents=text, config=config)
|
response = self.client.models.embed_content(model=self.config.model, contents=text, config=config)
|
||||||
|
|
||||||
return response.embeddings[0].values
|
return response.embeddings[0].values
|
||||||
|
|||||||
@@ -92,10 +92,12 @@ class AWSBedrockLLM(LLMBase):
|
|||||||
if response["output"]["message"]["content"]:
|
if response["output"]["message"]["content"]:
|
||||||
for item in response["output"]["message"]["content"]:
|
for item in response["output"]["message"]["content"]:
|
||||||
if "toolUse" in item:
|
if "toolUse" in item:
|
||||||
processed_response["tool_calls"].append({
|
processed_response["tool_calls"].append(
|
||||||
"name": item["toolUse"]["name"],
|
{
|
||||||
"arguments": item["toolUse"]["input"],
|
"name": item["toolUse"]["name"],
|
||||||
})
|
"arguments": item["toolUse"]["input"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
|
|
||||||
|
|||||||
@@ -165,7 +165,6 @@ class GeminiLLM(LLMBase):
|
|||||||
if system_instruction:
|
if system_instruction:
|
||||||
config_params["system_instruction"] = system_instruction
|
config_params["system_instruction"] = system_instruction
|
||||||
|
|
||||||
|
|
||||||
if response_format is not None and response_format["type"] == "json_object":
|
if response_format is not None and response_format["type"] == "json_object":
|
||||||
config_params["response_mime_type"] = "application/json"
|
config_params["response_mime_type"] = "application/json"
|
||||||
if "schema" in response_format:
|
if "schema" in response_format:
|
||||||
@@ -175,7 +174,6 @@ class GeminiLLM(LLMBase):
|
|||||||
formatted_tools = self._reformat_tools(tools)
|
formatted_tools = self._reformat_tools(tools)
|
||||||
config_params["tools"] = formatted_tools
|
config_params["tools"] = formatted_tools
|
||||||
|
|
||||||
|
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
if tool_choice == "auto":
|
if tool_choice == "auto":
|
||||||
mode = types.FunctionCallingConfigMode.AUTO
|
mode = types.FunctionCallingConfigMode.AUTO
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class SarvamLLM(LLMBase):
|
|||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
|
"Sarvam API key is required. Set SARVAM_API_KEY environment variable or provide api_key in config."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set base URL - use config value or environment or default
|
# Set base URL - use config value or environment or default
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from openai import OpenAI
|
|||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
from mem0.memory.utils import extract_json
|
from mem0.memory.utils import extract_json
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
|
|
||||||
class VllmLLM(LLMBase):
|
class VllmLLM(LLMBase):
|
||||||
@@ -41,10 +40,12 @@ class VllmLLM(LLMBase):
|
|||||||
|
|
||||||
if response.choices[0].message.tool_calls:
|
if response.choices[0].message.tool_calls:
|
||||||
for tool_call in response.choices[0].message.tool_calls:
|
for tool_call in response.choices[0].message.tool_calls:
|
||||||
processed_response["tool_calls"].append({
|
processed_response["tool_calls"].append(
|
||||||
"name": tool_call.function.name,
|
{
|
||||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
"name": tool_call.function.name,
|
||||||
})
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -136,7 +136,6 @@ class MemoryGraph:
|
|||||||
params = {"user_id": filters["user_id"]}
|
params = {"user_id": filters["user_id"]}
|
||||||
self.graph.query(cypher, params=params)
|
self.graph.query(cypher, params=params)
|
||||||
|
|
||||||
|
|
||||||
def get_all(self, filters, limit=100):
|
def get_all(self, filters, limit=100):
|
||||||
"""
|
"""
|
||||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||||
@@ -176,7 +175,6 @@ class MemoryGraph:
|
|||||||
|
|
||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_nodes_from_data(self, data, filters):
|
def _retrieve_nodes_from_data(self, data, filters):
|
||||||
"""Extracts all the entities mentioned in the query."""
|
"""Extracts all the entities mentioned in the query."""
|
||||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||||
@@ -213,7 +211,7 @@ class MemoryGraph:
|
|||||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||||
"""Establish relations among the extracted nodes."""
|
"""Establish relations among the extracted nodes."""
|
||||||
|
|
||||||
# Compose user identification string for prompt
|
# Compose user identification string for prompt
|
||||||
user_identity = f"user_id: {filters['user_id']}"
|
user_identity = f"user_id: {filters['user_id']}"
|
||||||
if filters.get("agent_id"):
|
if filters.get("agent_id"):
|
||||||
user_identity += f", agent_id: {filters['agent_id']}"
|
user_identity += f", agent_id: {filters['agent_id']}"
|
||||||
@@ -221,9 +219,7 @@ class MemoryGraph:
|
|||||||
if self.config.graph_store.custom_prompt:
|
if self.config.graph_store.custom_prompt:
|
||||||
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
|
||||||
# Add the custom prompt line if configured
|
# Add the custom prompt line if configured
|
||||||
system_content = system_content.replace(
|
system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
|
||||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
|
||||||
)
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": system_content},
|
{"role": "system", "content": system_content},
|
||||||
{"role": "user", "content": data},
|
{"role": "user", "content": data},
|
||||||
@@ -336,7 +332,7 @@ class MemoryGraph:
|
|||||||
user_id = filters["user_id"]
|
user_id = filters["user_id"]
|
||||||
agent_id = filters.get("agent_id", None)
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for item in to_be_deleted:
|
for item in to_be_deleted:
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
destination = item["destination"]
|
destination = item["destination"]
|
||||||
@@ -349,7 +345,7 @@ class MemoryGraph:
|
|||||||
"dest_name": destination,
|
"dest_name": destination,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if agent_id:
|
if agent_id:
|
||||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||||
params["agent_id"] = agent_id
|
params["agent_id"] = agent_id
|
||||||
@@ -366,10 +362,10 @@ class MemoryGraph:
|
|||||||
m.name AS target,
|
m.name AS target,
|
||||||
type(r) AS relationship
|
type(r) AS relationship
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _add_entities(self, to_be_added, filters, entity_type_map):
|
def _add_entities(self, to_be_added, filters, entity_type_map):
|
||||||
@@ -430,7 +426,7 @@ class MemoryGraph:
|
|||||||
r.mentions = coalesce(r.mentions, 0) + 1
|
r.mentions = coalesce(r.mentions, 0) + 1
|
||||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||||
"destination_name": destination,
|
"destination_name": destination,
|
||||||
@@ -592,7 +588,6 @@ class MemoryGraph:
|
|||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
|
||||||
agent_filter = ""
|
agent_filter = ""
|
||||||
if filters.get("agent_id"):
|
if filters.get("agent_id"):
|
||||||
|
|||||||
@@ -338,7 +338,7 @@ class Memory(MemoryBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in new_retrieved_facts: {e}")
|
logger.error(f"Error in new_retrieved_facts: {e}")
|
||||||
new_retrieved_facts = []
|
new_retrieved_facts = []
|
||||||
|
|
||||||
if not new_retrieved_facts:
|
if not new_retrieved_facts:
|
||||||
logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
|
logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
|
||||||
|
|
||||||
@@ -1166,7 +1166,7 @@ class AsyncMemory(MemoryBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in new_retrieved_facts: {e}")
|
logger.error(f"Error in new_retrieved_facts: {e}")
|
||||||
new_retrieved_facts = []
|
new_retrieved_facts = []
|
||||||
|
|
||||||
if not new_retrieved_facts:
|
if not new_retrieved_facts:
|
||||||
logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
|
logger.debug("No new facts retrieved from input. Skipping memory update LLM call.")
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class MemoryGraph:
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
params = {"user_id": filters["user_id"], "limit": limit}
|
params = {"user_id": filters["user_id"], "limit": limit}
|
||||||
|
|
||||||
results = self.graph.query(query, params=params)
|
results = self.graph.query(query, params=params)
|
||||||
|
|
||||||
final_results = []
|
final_results = []
|
||||||
@@ -318,7 +318,7 @@ class MemoryGraph:
|
|||||||
"user_id": filters["user_id"],
|
"user_id": filters["user_id"],
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
}
|
}
|
||||||
|
|
||||||
ans = self.graph.query(cypher_query, params=params)
|
ans = self.graph.query(cypher_query, params=params)
|
||||||
result_relations.extend(ans)
|
result_relations.extend(ans)
|
||||||
|
|
||||||
@@ -356,7 +356,7 @@ class MemoryGraph:
|
|||||||
user_id = filters["user_id"]
|
user_id = filters["user_id"]
|
||||||
agent_id = filters.get("agent_id", None)
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for item in to_be_deleted:
|
for item in to_be_deleted:
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
destination = item["destination"]
|
destination = item["destination"]
|
||||||
@@ -369,7 +369,7 @@ class MemoryGraph:
|
|||||||
"dest_name": destination,
|
"dest_name": destination,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if agent_id:
|
if agent_id:
|
||||||
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
||||||
params["agent_id"] = agent_id
|
params["agent_id"] = agent_id
|
||||||
@@ -386,10 +386,10 @@ class MemoryGraph:
|
|||||||
m.name AS target,
|
m.name AS target,
|
||||||
type(r) AS relationship
|
type(r) AS relationship
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# added Entity label to all nodes for vector search to work
|
# added Entity label to all nodes for vector search to work
|
||||||
@@ -398,7 +398,7 @@ class MemoryGraph:
|
|||||||
user_id = filters["user_id"]
|
user_id = filters["user_id"]
|
||||||
agent_id = filters.get("agent_id", None)
|
agent_id = filters.get("agent_id", None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for item in to_be_added:
|
for item in to_be_added:
|
||||||
# entities
|
# entities
|
||||||
source = item["source"]
|
source = item["source"]
|
||||||
@@ -421,7 +421,7 @@ class MemoryGraph:
|
|||||||
agent_id_clause = ""
|
agent_id_clause = ""
|
||||||
if agent_id:
|
if agent_id:
|
||||||
agent_id_clause = ", agent_id: $agent_id"
|
agent_id_clause = ", agent_id: $agent_id"
|
||||||
|
|
||||||
# TODO: Create a cypher query and common params for all the cases
|
# TODO: Create a cypher query and common params for all the cases
|
||||||
if not destination_node_search_result and source_node_search_result:
|
if not destination_node_search_result and source_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
@@ -446,7 +446,7 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
if agent_id:
|
if agent_id:
|
||||||
params["agent_id"] = agent_id
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
elif destination_node_search_result and not source_node_search_result:
|
elif destination_node_search_result and not source_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (destination:Entity)
|
MATCH (destination:Entity)
|
||||||
@@ -470,7 +470,7 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
if agent_id:
|
if agent_id:
|
||||||
params["agent_id"] = agent_id
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
elif source_node_search_result and destination_node_search_result:
|
elif source_node_search_result and destination_node_search_result:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (source:Entity)
|
MATCH (source:Entity)
|
||||||
@@ -490,7 +490,7 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
if agent_id:
|
if agent_id:
|
||||||
params["agent_id"] = agent_id
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
|
||||||
@@ -512,7 +512,7 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
if agent_id:
|
if agent_id:
|
||||||
params["agent_id"] = agent_id
|
params["agent_id"] = agent_id
|
||||||
|
|
||||||
result = self.graph.query(cypher, params=params)
|
result = self.graph.query(cypher, params=params)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
return results
|
return results
|
||||||
@@ -528,7 +528,7 @@ class MemoryGraph:
|
|||||||
"""Search for source nodes with similar embeddings."""
|
"""Search for source nodes with similar embeddings."""
|
||||||
user_id = filters["user_id"]
|
user_id = filters["user_id"]
|
||||||
agent_id = filters.get("agent_id", None)
|
agent_id = filters.get("agent_id", None)
|
||||||
|
|
||||||
if agent_id:
|
if agent_id:
|
||||||
cypher = """
|
cypher = """
|
||||||
CALL vector_search.search("memzero", 1, $source_embedding)
|
CALL vector_search.search("memzero", 1, $source_embedding)
|
||||||
@@ -567,7 +567,7 @@ class MemoryGraph:
|
|||||||
"""Search for destination nodes with similar embeddings."""
|
"""Search for destination nodes with similar embeddings."""
|
||||||
user_id = filters["user_id"]
|
user_id = filters["user_id"]
|
||||||
agent_id = filters.get("agent_id", None)
|
agent_id = filters.get("agent_id", None)
|
||||||
|
|
||||||
if agent_id:
|
if agent_id:
|
||||||
cypher = """
|
cypher = """
|
||||||
CALL vector_search.search("memzero", 1, $destination_embedding)
|
CALL vector_search.search("memzero", 1, $destination_embedding)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Dict, Any, Callable
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -26,13 +26,7 @@ class MongoDB(VectorStoreBase):
|
|||||||
VECTOR_TYPE = "knnVector"
|
VECTOR_TYPE = "knnVector"
|
||||||
SIMILARITY_METRIC = "cosine"
|
SIMILARITY_METRIC = "cosine"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
|
||||||
self,
|
|
||||||
db_name: str,
|
|
||||||
collection_name: str,
|
|
||||||
embedding_model_dims: int,
|
|
||||||
mongo_uri: str
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Initialize the MongoDB vector store with vector search capabilities.
|
Initialize the MongoDB vector store with vector search capabilities.
|
||||||
|
|
||||||
@@ -46,9 +40,7 @@ class MongoDB(VectorStoreBase):
|
|||||||
self.embedding_model_dims = embedding_model_dims
|
self.embedding_model_dims = embedding_model_dims
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
|
|
||||||
self.client = MongoClient(
|
self.client = MongoClient(mongo_uri)
|
||||||
mongo_uri
|
|
||||||
)
|
|
||||||
self.db = self.client[db_name]
|
self.db = self.client[db_name]
|
||||||
self.collection = self.create_col()
|
self.collection = self.create_col()
|
||||||
|
|
||||||
@@ -119,7 +111,9 @@ class MongoDB(VectorStoreBase):
|
|||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(f"Error inserting data: {e}")
|
logger.error(f"Error inserting data: {e}")
|
||||||
|
|
||||||
def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
|
def search(
|
||||||
|
self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
|
||||||
|
) -> List[OutputData]:
|
||||||
"""
|
"""
|
||||||
Search for similar vectors using the vector search index.
|
Search for similar vectors using the vector search index.
|
||||||
|
|
||||||
@@ -285,7 +279,7 @@ class MongoDB(VectorStoreBase):
|
|||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(f"Error listing documents: {e}")
|
logger.error(f"Error listing documents: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the index by deleting and recreating it."""
|
"""Reset the index by deleting and recreating it."""
|
||||||
logger.warning(f"Resetting index {self.collection_name}...")
|
logger.warning(f"Resetting index {self.collection_name}...")
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class OpenSearchDB(VectorStoreBase):
|
|||||||
self.client.indices.create(index=name, body=index_settings)
|
self.client.indices.create(index=name, body=index_settings)
|
||||||
|
|
||||||
# Wait for index to be ready
|
# Wait for index to be ready
|
||||||
max_retries = 180 # 3 minutes timeout
|
max_retries = 180 # 3 minutes timeout
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
while retry_count < max_retries:
|
while retry_count < max_retries:
|
||||||
try:
|
try:
|
||||||
|
|||||||
5797
poetry.lock
generated
5797
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -30,11 +30,13 @@ vector_stores = [
|
|||||||
"vecs>=0.4.0",
|
"vecs>=0.4.0",
|
||||||
"chromadb>=0.4.24",
|
"chromadb>=0.4.24",
|
||||||
"weaviate-client>=4.4.0",
|
"weaviate-client>=4.4.0",
|
||||||
"pinecone<7.0.0",
|
"pinecone<=7.3.0",
|
||||||
"pinecone-text>=0.1.1",
|
"pinecone-text>=0.10.0",
|
||||||
"faiss-cpu>=1.7.4",
|
"faiss-cpu>=1.7.4",
|
||||||
"upstash-vector>=0.1.0",
|
"upstash-vector>=0.1.0",
|
||||||
"azure-search-documents>=11.4.0b8",
|
"azure-search-documents>=11.4.0b8",
|
||||||
|
"pymongo>=4.13.2",
|
||||||
|
"pymochow>=2.2.9",
|
||||||
]
|
]
|
||||||
llms = [
|
llms = [
|
||||||
"groq>=0.3.0",
|
"groq>=0.3.0",
|
||||||
@@ -44,12 +46,11 @@ llms = [
|
|||||||
"vertexai>=0.1.0",
|
"vertexai>=0.1.0",
|
||||||
"google-generativeai>=0.3.0",
|
"google-generativeai>=0.3.0",
|
||||||
"google-genai>=1.0.0",
|
"google-genai>=1.0.0",
|
||||||
|
|
||||||
]
|
]
|
||||||
extras = [
|
extras = [
|
||||||
"boto3>=1.34.0",
|
"boto3>=1.34.0",
|
||||||
"langchain-community>=0.0.0",
|
"langchain-community>=0.0.0",
|
||||||
"sentence-transformers>=2.2.2",
|
"sentence-transformers>=5.0.0",
|
||||||
"elasticsearch>=8.0.0",
|
"elasticsearch>=8.0.0",
|
||||||
"opensearch-py>=2.0.0",
|
"opensearch-py>=2.0.0",
|
||||||
"langchain-memgraph>=0.1.0",
|
"langchain-memgraph>=0.1.0",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch, ANY
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -8,8 +8,10 @@ from mem0.embeddings.gemini import GoogleGenAIEmbedding
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_genai():
|
def mock_genai():
|
||||||
with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai:
|
with patch("mem0.embeddings.gemini.genai.Client") as mock_client_class:
|
||||||
yield mock_genai
|
mock_client = mock_client_class.return_value
|
||||||
|
mock_client.models.embed_content.return_value = None
|
||||||
|
yield mock_client.models.embed_content
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -18,7 +20,9 @@ def config():
|
|||||||
|
|
||||||
|
|
||||||
def test_embed_query(mock_genai, config):
|
def test_embed_query(mock_genai, config):
|
||||||
mock_embedding_response = {"embedding": [0.1, 0.2, 0.3, 0.4]}
|
mock_embedding_response = type('Response', (), {
|
||||||
|
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
|
||||||
|
})()
|
||||||
mock_genai.return_value = mock_embedding_response
|
mock_genai.return_value = mock_embedding_response
|
||||||
|
|
||||||
embedder = GoogleGenAIEmbedding(config)
|
embedder = GoogleGenAIEmbedding(config)
|
||||||
@@ -27,10 +31,11 @@ def test_embed_query(mock_genai, config):
|
|||||||
embedding = embedder.embed(text)
|
embedding = embedder.embed(text)
|
||||||
|
|
||||||
assert embedding == [0.1, 0.2, 0.3, 0.4]
|
assert embedding == [0.1, 0.2, 0.3, 0.4]
|
||||||
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
|
mock_genai.assert_called_once_with(model="test_model", contents="Hello, world!", config=ANY)
|
||||||
|
|
||||||
|
|
||||||
def test_embed_returns_empty_list_if_none(mock_genai, config):
|
def test_embed_returns_empty_list_if_none(mock_genai, config):
|
||||||
mock_genai.return_value = None
|
mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()
|
||||||
|
|
||||||
embedder = GoogleGenAIEmbedding(config)
|
embedder = GoogleGenAIEmbedding(config)
|
||||||
result = embedder.embed("test")
|
result = embedder.embed("test")
|
||||||
@@ -47,10 +52,10 @@ def test_embed_raises_on_error(mock_genai, config):
|
|||||||
with pytest.raises(RuntimeError, match="Embedding failed"):
|
with pytest.raises(RuntimeError, match="Embedding failed"):
|
||||||
embedder.embed("some input")
|
embedder.embed("some input")
|
||||||
|
|
||||||
|
|
||||||
def test_config_initialization(config):
|
def test_config_initialization(config):
|
||||||
embedder = GoogleGenAIEmbedding(config)
|
embedder = GoogleGenAIEmbedding(config)
|
||||||
|
|
||||||
assert embedder.config.api_key == "dummy_api_key"
|
assert embedder.config.api_key == "dummy_api_key"
|
||||||
assert embedder.config.model == "test_model"
|
assert embedder.config.model == "test_model"
|
||||||
assert embedder.config.embedding_dims == 786
|
assert embedder.config.embedding_dims == 786
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from mem0.llms.gemini import GeminiLLM
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_gemini_client():
|
def mock_gemini_client():
|
||||||
with patch("mem0.llms.gemini.genai") as mock_client_class:
|
with patch("mem0.llms.gemini.genai.Client") as mock_client_class:
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
yield mock_client
|
yield mock_client
|
||||||
@@ -24,43 +24,30 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
|
|||||||
]
|
]
|
||||||
|
|
||||||
mock_part = Mock(text="I'm doing well, thank you for asking!")
|
mock_part = Mock(text="I'm doing well, thank you for asking!")
|
||||||
mock_embedding = Mock()
|
mock_content = Mock(parts=[mock_part])
|
||||||
mock_embedding.values = [0.1, 0.2, 0.3]
|
mock_candidate = Mock(content=mock_content)
|
||||||
|
mock_response = Mock(candidates=[mock_candidate])
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.candidates = [Mock()]
|
|
||||||
mock_response.candidates[0].content.parts = [Mock()]
|
|
||||||
mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
|
|
||||||
|
|
||||||
mock_gemini_client.models.generate_content.return_value = mock_response
|
mock_gemini_client.models.generate_content.return_value = mock_response
|
||||||
mock_content = Mock(parts=[mock_part])
|
|
||||||
mock_message = Mock(content=mock_content)
|
|
||||||
mock_response = Mock(candidates=[mock_message])
|
|
||||||
mock_gemini_client.generate_content.return_value = mock_response
|
|
||||||
|
|
||||||
response = llm.generate_response(messages)
|
response = llm.generate_response(messages)
|
||||||
|
|
||||||
mock_gemini_client.generate_content.assert_called_once_with(
|
# Check the actual call - system instruction is now in config
|
||||||
contents=[
|
mock_gemini_client.models.generate_content.assert_called_once()
|
||||||
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
|
call_args = mock_gemini_client.models.generate_content.call_args
|
||||||
{"parts": "Hello, how are you?", "role": "user"},
|
|
||||||
],
|
|
||||||
config=types.GenerateContentConfig(
|
|
||||||
temperature=0.7,
|
|
||||||
max_output_tokens=100,
|
|
||||||
top_p=1.0,
|
|
||||||
tools=None,
|
|
||||||
tool_config=types.ToolConfig(
|
|
||||||
function_calling_config=types.FunctionCallingConfig(
|
|
||||||
allowed_function_names=None,
|
|
||||||
mode="auto"
|
|
||||||
|
|
||||||
)
|
|
||||||
)
|
|
||||||
) )
|
|
||||||
|
|
||||||
assert response == "I'm doing well, thank you for asking!"
|
|
||||||
|
|
||||||
|
# Verify model and contents
|
||||||
|
assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
|
||||||
|
assert len(call_args.kwargs['contents']) == 1 # Only user message
|
||||||
|
|
||||||
|
# Verify config has system instruction
|
||||||
|
config_arg = call_args.kwargs['config']
|
||||||
|
assert config_arg.system_instruction == "You are a helpful assistant."
|
||||||
|
assert config_arg.temperature == 0.7
|
||||||
|
assert config_arg.max_output_tokens == 100
|
||||||
|
assert config_arg.top_p == 1.0
|
||||||
|
|
||||||
|
assert response == "I'm doing well, thank you for asking!"
|
||||||
|
|
||||||
|
|
||||||
def test_generate_response_with_tools(mock_gemini_client: Mock):
|
def test_generate_response_with_tools(mock_gemini_client: Mock):
|
||||||
@@ -89,64 +76,42 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
|
|||||||
mock_tool_call.name = "add_memory"
|
mock_tool_call.name = "add_memory"
|
||||||
mock_tool_call.args = {"data": "Today is a sunny day."}
|
mock_tool_call.args = {"data": "Today is a sunny day."}
|
||||||
|
|
||||||
mock_part = Mock()
|
# Create mock parts with both text and function_call
|
||||||
mock_part.function_call = mock_tool_call
|
mock_text_part = Mock()
|
||||||
mock_part.text = "I've added the memory for you."
|
mock_text_part.text = "I've added the memory for you."
|
||||||
|
mock_text_part.function_call = None
|
||||||
|
|
||||||
|
mock_func_part = Mock()
|
||||||
|
mock_func_part.text = None
|
||||||
|
mock_func_part.function_call = mock_tool_call
|
||||||
|
|
||||||
mock_content = Mock()
|
mock_content = Mock()
|
||||||
mock_content.parts = [mock_part]
|
mock_content.parts = [mock_text_part, mock_func_part]
|
||||||
|
|
||||||
mock_message = Mock()
|
mock_candidate = Mock()
|
||||||
mock_message.content = mock_content
|
mock_candidate.content = mock_content
|
||||||
|
|
||||||
mock_response = Mock(candidates=[mock_message])
|
mock_response = Mock(candidates=[mock_candidate])
|
||||||
mock_gemini_client.generate_content.return_value = mock_response
|
mock_gemini_client.models.generate_content.return_value = mock_response
|
||||||
|
|
||||||
response = llm.generate_response(messages, tools=tools)
|
response = llm.generate_response(messages, tools=tools)
|
||||||
|
|
||||||
mock_gemini_client.generate_content.assert_called_once_with(
|
# Check the actual call
|
||||||
contents=[
|
mock_gemini_client.models.generate_content.assert_called_once()
|
||||||
{
|
call_args = mock_gemini_client.models.generate_content.call_args
|
||||||
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
|
|
||||||
"role": "user"
|
# Verify model and contents
|
||||||
},
|
assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
|
||||||
{
|
assert len(call_args.kwargs['contents']) == 1 # Only user message
|
||||||
"parts": "Add a new memory: Today is a sunny day.",
|
|
||||||
"role": "user"
|
# Verify config has system instruction and tools
|
||||||
},
|
config_arg = call_args.kwargs['config']
|
||||||
],
|
assert config_arg.system_instruction == "You are a helpful assistant."
|
||||||
config=types.GenerateContentConfig(
|
assert config_arg.temperature == 0.7
|
||||||
temperature=0.7,
|
assert config_arg.max_output_tokens == 100
|
||||||
max_output_tokens=100,
|
assert config_arg.top_p == 1.0
|
||||||
top_p=1.0,
|
assert len(config_arg.tools) == 1
|
||||||
tools=[
|
assert config_arg.tool_config.function_calling_config.mode == types.FunctionCallingConfigMode.AUTO
|
||||||
types.Tool(
|
|
||||||
function_declarations=[
|
|
||||||
types.FunctionDeclaration(
|
|
||||||
name="add_memory",
|
|
||||||
description="Add a memory",
|
|
||||||
parameters={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"data": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Data to add to memory"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["data"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
],
|
|
||||||
tool_config=types.ToolConfig(
|
|
||||||
function_calling_config=types.FunctionCallingConfig(
|
|
||||||
allowed_function_names=None,
|
|
||||||
mode="auto"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response["content"] == "I've added the memory for you."
|
assert response["content"] == "I've added the memory for you."
|
||||||
assert len(response["tool_calls"]) == 1
|
assert len(response["tool_calls"]) == 1
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ def test_generate_response_without_tools(mock_lm_studio_client):
|
|||||||
|
|
||||||
assert response == "I'm doing well, thank you for asking!"
|
assert response == "I'm doing well, thank you for asking!"
|
||||||
|
|
||||||
|
|
||||||
def test_generate_response_specifying_response_format(mock_lm_studio_client):
|
def test_generate_response_specifying_response_format(mock_lm_studio_client):
|
||||||
config = BaseLlmConfig(
|
config = BaseLlmConfig(
|
||||||
model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
|
model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
|
||||||
@@ -68,4 +69,4 @@ def test_generate_response_specifying_response_format(mock_lm_studio_client):
|
|||||||
response_format={"type": "json_schema"},
|
response_format={"type": "json_schema"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response == "I'm doing well, thank you for asking!"
|
assert response == "I'm doing well, thank you for asking!"
|
||||||
|
|||||||
@@ -71,7 +71,13 @@ def test_generate_response_with_tools(mock_vllm_client):
|
|||||||
response = llm.generate_response(messages, tools=tools)
|
response = llm.generate_response(messages, tools=tools)
|
||||||
|
|
||||||
mock_vllm_client.chat.completions.create.assert_called_once_with(
|
mock_vllm_client.chat.completions.create.assert_called_once_with(
|
||||||
model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
|
model="Qwen/Qwen2.5-32B-Instruct",
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=100,
|
||||||
|
top_p=1.0,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response["content"] == "I've added the memory for you."
|
assert response["content"] == "I've added the memory for you."
|
||||||
|
|||||||
@@ -253,10 +253,10 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
|
|||||||
def test_custom_prompts(memory_custom_instance):
|
def test_custom_prompts(memory_custom_instance):
|
||||||
messages = [{"role": "user", "content": "Test message"}]
|
messages = [{"role": "user", "content": "Test message"}]
|
||||||
from mem0.embeddings.mock import MockEmbeddings
|
from mem0.embeddings.mock import MockEmbeddings
|
||||||
|
|
||||||
memory_custom_instance.llm.generate_response = Mock()
|
memory_custom_instance.llm.generate_response = Mock()
|
||||||
memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}'
|
memory_custom_instance.llm.generate_response.return_value = '{"facts": ["fact1", "fact2"]}'
|
||||||
memory_custom_instance.embedding_model = MockEmbeddings()
|
memory_custom_instance.embedding_model = MockEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
|
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ from unittest.mock import Mock, patch, PropertyMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mem0.vector_stores.baidu import BaiduDB, OutputData
|
from mem0.vector_stores.baidu import BaiduDB
|
||||||
from pymochow.model.enum import MetricType, TableState, ServerErrCode
|
from pymochow.model.enum import TableState, ServerErrCode
|
||||||
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
|
from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
|
||||||
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
|
|
||||||
from pymochow.exception import ServerError
|
from pymochow.exception import ServerError
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from mem0.vector_stores.mongodb import MongoDB
|
from mem0.vector_stores.mongodb import MongoDB
|
||||||
from pymongo.operations import SearchIndexModel
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@patch("mem0.vector_stores.mongodb.MongoClient")
|
@patch("mem0.vector_stores.mongodb.MongoClient")
|
||||||
@@ -19,10 +19,11 @@ def mongo_vector_fixture(mock_mongo_client):
|
|||||||
db_name="test_db",
|
db_name="test_db",
|
||||||
collection_name="test_collection",
|
collection_name="test_collection",
|
||||||
embedding_model_dims=1536,
|
embedding_model_dims=1536,
|
||||||
mongo_uri="mongodb://username:password@localhost:27017"
|
mongo_uri="mongodb://username:password@localhost:27017",
|
||||||
)
|
)
|
||||||
return mongo_vector, mock_collection, mock_db
|
return mongo_vector, mock_collection, mock_db
|
||||||
|
|
||||||
|
|
||||||
def test_initalize_create_col(mongo_vector_fixture):
|
def test_initalize_create_col(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, mock_db = mongo_vector_fixture
|
mongo_vector, mock_collection, mock_db = mongo_vector_fixture
|
||||||
assert mongo_vector.collection_name == "test_collection"
|
assert mongo_vector.collection_name == "test_collection"
|
||||||
@@ -49,12 +50,13 @@ def test_initalize_create_col(mongo_vector_fixture):
|
|||||||
"dimensions": 1536,
|
"dimensions": 1536,
|
||||||
"similarity": "cosine",
|
"similarity": "cosine",
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
assert mongo_vector.collection == mock_collection
|
assert mongo_vector.collection == mock_collection
|
||||||
|
|
||||||
|
|
||||||
def test_insert(mongo_vector_fixture):
|
def test_insert(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
vectors = [[0.1] * 1536, [0.2] * 1536]
|
vectors = [[0.1] * 1536, [0.2] * 1536]
|
||||||
@@ -62,12 +64,13 @@ def test_insert(mongo_vector_fixture):
|
|||||||
ids = ["id1", "id2"]
|
ids = ["id1", "id2"]
|
||||||
|
|
||||||
mongo_vector.insert(vectors, payloads, ids)
|
mongo_vector.insert(vectors, payloads, ids)
|
||||||
expected_records=[
|
expected_records = [
|
||||||
({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}),
|
({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}),
|
||||||
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]})
|
({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}),
|
||||||
]
|
]
|
||||||
mock_collection.insert_many.assert_called_once_with(expected_records)
|
mock_collection.insert_many.assert_called_once_with(expected_records)
|
||||||
|
|
||||||
|
|
||||||
def test_search(mongo_vector_fixture):
|
def test_search(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
query_vector = [0.1] * 1536
|
query_vector = [0.1] * 1536
|
||||||
@@ -79,25 +82,28 @@ def test_search(mongo_vector_fixture):
|
|||||||
|
|
||||||
results = mongo_vector.search("query_str", query_vector, limit=2)
|
results = mongo_vector.search("query_str", query_vector, limit=2)
|
||||||
mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index")
|
mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index")
|
||||||
mock_collection.aggregate.assert_called_once_with([
|
mock_collection.aggregate.assert_called_once_with(
|
||||||
{
|
[
|
||||||
"$vectorSearch": {
|
{
|
||||||
"index": "test_collection_vector_index",
|
"$vectorSearch": {
|
||||||
"limit": 2,
|
"index": "test_collection_vector_index",
|
||||||
"numCandidates": 2,
|
"limit": 2,
|
||||||
"queryVector": query_vector,
|
"numCandidates": 2,
|
||||||
"path": "embedding",
|
"queryVector": query_vector,
|
||||||
|
"path": "embedding",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
|
||||||
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
|
{"$project": {"embedding": 0}},
|
||||||
{"$project": {"embedding": 0}},
|
]
|
||||||
])
|
)
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert results[0].id == "id1"
|
assert results[0].id == "id1"
|
||||||
assert results[0].score == 0.9
|
assert results[0].score == 0.9
|
||||||
assert results[1].id == "id2"
|
assert results[1].id == "id2"
|
||||||
assert results[1].score == 0.8
|
assert results[1].score == 0.8
|
||||||
|
|
||||||
|
|
||||||
def test_delete(mongo_vector_fixture):
|
def test_delete(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
mock_delete_result = MagicMock()
|
mock_delete_result = MagicMock()
|
||||||
@@ -107,6 +113,7 @@ def test_delete(mongo_vector_fixture):
|
|||||||
mongo_vector.delete("id1")
|
mongo_vector.delete("id1")
|
||||||
mock_collection.delete_one.assert_called_with({"_id": "id1"})
|
mock_collection.delete_one.assert_called_with({"_id": "id1"})
|
||||||
|
|
||||||
|
|
||||||
def test_update(mongo_vector_fixture):
|
def test_update(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
mock_update_result = MagicMock()
|
mock_update_result = MagicMock()
|
||||||
@@ -122,6 +129,7 @@ def test_update(mongo_vector_fixture):
|
|||||||
{"$set": {"embedding": vectorValue, "payload": payloadValue}},
|
{"$set": {"embedding": vectorValue, "payload": payloadValue}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get(mongo_vector_fixture):
|
def test_get(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}}
|
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}}
|
||||||
@@ -131,6 +139,7 @@ def test_get(mongo_vector_fixture):
|
|||||||
assert result.id == "id1"
|
assert result.id == "id1"
|
||||||
assert result.payload == {"key": "value1"}
|
assert result.payload == {"key": "value1"}
|
||||||
|
|
||||||
|
|
||||||
def test_list_cols(mongo_vector_fixture):
|
def test_list_cols(mongo_vector_fixture):
|
||||||
mongo_vector, _, mock_db = mongo_vector_fixture
|
mongo_vector, _, mock_db = mongo_vector_fixture
|
||||||
mock_db.list_collection_names.return_value = ["col1", "col2"]
|
mock_db.list_collection_names.return_value = ["col1", "col2"]
|
||||||
@@ -138,12 +147,14 @@ def test_list_cols(mongo_vector_fixture):
|
|||||||
collections = mongo_vector.list_cols()
|
collections = mongo_vector.list_cols()
|
||||||
assert collections == ["col1", "col2"]
|
assert collections == ["col1", "col2"]
|
||||||
|
|
||||||
|
|
||||||
def test_delete_col(mongo_vector_fixture):
|
def test_delete_col(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
|
|
||||||
mongo_vector.delete_col()
|
mongo_vector.delete_col()
|
||||||
mock_collection.drop.assert_called_once()
|
mock_collection.drop.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_col_info(mongo_vector_fixture):
|
def test_col_info(mongo_vector_fixture):
|
||||||
mongo_vector, _, mock_db = mongo_vector_fixture
|
mongo_vector, _, mock_db = mongo_vector_fixture
|
||||||
mock_db.command.return_value = {"count": 10, "size": 1024}
|
mock_db.command.return_value = {"count": 10, "size": 1024}
|
||||||
@@ -154,6 +165,7 @@ def test_col_info(mongo_vector_fixture):
|
|||||||
assert info["count"] == 10
|
assert info["count"] == 10
|
||||||
assert info["size"] == 1024
|
assert info["size"] == 1024
|
||||||
|
|
||||||
|
|
||||||
def test_list(mongo_vector_fixture):
|
def test_list(mongo_vector_fixture):
|
||||||
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
mongo_vector, mock_collection, _ = mongo_vector_fixture
|
||||||
mock_cursor = MagicMock()
|
mock_cursor = MagicMock()
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ def test_get_vector_found(pinecone_db):
|
|||||||
# or a list of dictionaries, not a dictionary with an 'id' field
|
# or a list of dictionaries, not a dictionary with an 'id' field
|
||||||
|
|
||||||
# Create a mock Vector object
|
# Create a mock Vector object
|
||||||
from pinecone.data.dataclasses.vector import Vector
|
from pinecone import Vector
|
||||||
|
|
||||||
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
|
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user