diff --git a/Makefile b/Makefile
index 76a25de1..965a7193 100644
--- a/Makefile
+++ b/Makefile
@@ -16,15 +16,15 @@ install_all:
# Format code with ruff
format:
- poetry run ruff check . --fix $(RUFF_OPTIONS)
+ poetry run ruff format mem0/
# Sort imports with isort
sort:
- poetry run isort . $(ISORT_OPTIONS)
+ poetry run isort mem0/
# Lint code with ruff
lint:
- poetry run ruff .
+ poetry run ruff check mem0/
docs:
cd docs && mintlify dev
diff --git a/cookbooks/add_memory_using_qdrant_cloud.py b/cookbooks/add_memory_using_qdrant_cloud.py
index d7142752..0ca02e52 100644
--- a/cookbooks/add_memory_using_qdrant_cloud.py
+++ b/cookbooks/add_memory_using_qdrant_cloud.py
@@ -7,27 +7,21 @@ from mem0 import Memory
# Loading OpenAI API Key
load_dotenv()
-OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
USER_ID = "test"
-quadrant_host="xx.gcp.cloud.qdrant.io"
+quadrant_host = "xx.gcp.cloud.qdrant.io"
# creating the config attributes
-collection_name="memory" # this is the collection I created in QDRANT cloud
-api_key=os.environ.get("QDRANT_API_KEY") # Getting the QDRANT api KEY
-host=quadrant_host
-port=6333 #Default port for QDRANT cloud
+collection_name = "memory" # this is the collection I created in QDRANT cloud
+api_key = os.environ.get("QDRANT_API_KEY") # Getting the QDRANT api KEY
+host = quadrant_host
+port = 6333 # Default port for QDRANT cloud
# Creating the config dict
config = {
"vector_store": {
"provider": "qdrant",
- "config": {
- "collection_name": collection_name,
- "host": host,
- "port": port,
- "path": None,
- "api_key":api_key
- }
+ "config": {"collection_name": collection_name, "host": host, "port": port, "path": None, "api_key": api_key},
}
}
diff --git a/cookbooks/mem0-multion.ipynb b/cookbooks/mem0-multion.ipynb
index 3cd3fc97..98e30456 100644
--- a/cookbooks/mem0-multion.ipynb
+++ b/cookbooks/mem0-multion.ipynb
@@ -1,189 +1,189 @@
{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "y4bKPPa7DXNs"
- },
- "outputs": [],
- "source": [
- "%pip install mem0ai multion"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pe4htqUmDdmS"
- },
- "source": [
- "## Setup and Configuration\n",
- "\n",
- "First, we'll import the necessary libraries and set up our configurations.\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "fsZwK7eLDh3I"
- },
- "outputs": [],
- "source": [
- "import os\n",
- "from mem0 import Memory\n",
- "from multion.client import MultiOn\n",
- "\n",
- "# Configuration\n",
- "OPENAI_API_KEY = 'sk-xxx' # Replace with your actual OpenAI API key\n",
- "MULTION_API_KEY = 'your-multion-key' # Replace with your actual MultiOn API key\n",
- "USER_ID = \"deshraj\"\n",
- "\n",
- "# Set up OpenAI API key\n",
- "os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY\n",
- "\n",
- "# Initialize Mem0 and MultiOn\n",
- "memory = Memory()\n",
- "multion = MultiOn(api_key=MULTION_API_KEY)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HTGVhGwaDl-1"
- },
- "source": [
- "## Add memories to Mem0\n",
- "\n",
- "Next, we'll define our user data and add it to Mem0."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "xB3tm0_pDm6e",
- "outputId": "aeab370c-8679-4d39-faaa-f702146d2fc4"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "User data added to memory.\n"
- ]
- }
- ],
- "source": [
- "# Define user data\n",
- "USER_DATA = \"\"\"\n",
- "About me\n",
- "- I'm Deshraj Yadav, Co-founder and CTO at Mem0 (f.k.a Embedchain). I am broadly interested in the field of Artificial Intelligence and Machine Learning Infrastructure.\n",
- "- Previously, I was Senior Autopilot Engineer at Tesla Autopilot where I led the Autopilot's AI Platform which helped the Tesla Autopilot team to track large scale training and model evaluation experiments, provide monitoring and observability into jobs and training cluster issues.\n",
- "- I had built EvalAI as my masters thesis at Georgia Tech, which is an open-source platform for evaluating and comparing machine learning and artificial intelligence algorithms at scale.\n",
- "- Outside of work, I am very much into cricket and play in two leagues (Cricbay and NACL) in San Francisco Bay Area.\n",
- "\"\"\"\n",
- "\n",
- "# Add user data to memory\n",
- "memory.add(USER_DATA, user_id=USER_ID)\n",
- "print(\"User data added to memory.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ZCPUJf0TDqUK"
- },
- "source": [
- "## Retrieving Relevant Memories\n",
- "\n",
- "Now, we'll define our search command and retrieve relevant memories from Mem0."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "s0PwAhNVDrIv",
- "outputId": "59cbb767-b468-4139-8d0c-fa763918dbb0"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Relevant memories:\n",
- "Name: Deshraj Yadav - Co-founder and CTO at Mem0 (formerly known as Embedchain) - Interested in Artificial Intelligence and Machine Learning Infrastructure - Previous role: Senior Autopilot Engineer at Tesla Autopilot - Led the Autopilot's AI Platform at Tesla, focusing on large scale training, model evaluation, monitoring, and observability - Built EvalAI as a master's thesis at Georgia Tech, an open-source platform for evaluating and comparing machine learning algorithms - Enjoys cricket - Plays in two cricket leagues: Cricbay and NACL in the San Francisco Bay Area\n"
- ]
- }
- ],
- "source": [
- "# Define search command and retrieve relevant memories\n",
- "command = \"Find papers on arxiv that I should read based on my interests.\"\n",
- "\n",
- "relevant_memories = memory.search(command, user_id=USER_ID, limit=3)\n",
- "relevant_memories_text = '\\n'.join(mem['memory'] for mem in relevant_memories)\n",
- "print(f\"Relevant memories:\")\n",
- "print(relevant_memories_text)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jdge78_VDtgv"
- },
- "source": [
- "## Browsing arXiv\n",
- "\n",
- "Finally, we'll use MultiOn to browse arXiv based on our command and relevant memories."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "4T_tLURTDvS-",
- "outputId": "259ff32f-5d42-44e6-f2ef-c3557a8e9da6"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "message=\"Summarizing the relevant papers found so far that align with Deshraj Yadav's interests in Artificial Intelligence and Machine Learning Infrastructure.\\n\\n1. **Urban Waterlogging Detection: A Challenging Benchmark and Large-Small Model Co-Adapter**\\n - Authors: Suqi Song, Chenxu Zhang, Peng Zhang, Pengkun Li, Fenglong Song, Lei Zhang\\n - Abstract: Urban waterlogging poses a major risk to public safety. Conventional methods using water-level sensors need high-maintenance to hardly achieve full coverage. Recent advances employ surveillance camera imagery and deep learning for detection, yet these struggle amidst scarce data and adverse environments.\\n - Date: 10 July, 2024\\n\\n2. **Intercepting Unauthorized Aerial Robots in Controlled Airspace Using Reinforcement Learning**\\n - Authors: Francisco Giral, Ignacio Gómez, Soledad Le Clainche\\n - Abstract: Ensuring the safe and efficient operation of airspace, particularly in urban environments and near critical infrastructure, necessitates effective methods to intercept unauthorized or non-cooperative UAVs. This work addresses the critical need for robust, adaptive systems capable of managing such scenarios.\\n - Date: 9 July, 2024\\n\\n3. **Efficient Materials Informatics between Rockets and Electrons**\\n - Authors: Adam M. Krajewski\\n - Abstract: This paper discusses the distinct efforts existing at three general scales of abstractions of what a material is - atomistic, physical, and design. At each, an efficient materials informatics is being built from the ground up based on the fundamental understanding of the underlying prior knowledge, including the data.\\n - Date: 5 July, 2024\\n\\n4. **ObfuscaTune: Obfuscated Offsite Fine-tuning and Inference of Proprietary LLMs on Private Datasets**\\n - Authors: Ahmed Frikha, Nassim Walha, Ricardo Mendes, Krishna Kanth Nakka, Xue Jiang, Xuebing Zhou\\n - Abstract: This paper proposes ObfuscaTune, a novel, efficient, and fully utility-preserving approach that combines a simple yet effective method to ensure the confidentiality of both the model and the data during offsite fine-tuning on a third-party cloud provider.\\n - Date: 3 July, 2024\\n\\n5. **MG-Verilog: Multi-grained Dataset Towards Enhanced LLM-assisted Verilog Generation**\\n - Authors: Yongan Zhang, Zhongzhi Yu, Yonggan Fu, Cheng Wan, Yingyan Celine Lin\\n - Abstract: This paper discusses the necessity of providing domain-specific data during inference, fine-tuning, or pre-training to effectively leverage LLMs in hardware design. Existing publicly available hardware datasets are often limited in size, complexity, or detail, which hinders the effectiveness of LLMs in this domain.\\n - Date: 1 July, 2024\\n\\n6. **The Future of Aerial Communications: A Survey of IRS-Enhanced UAV Communication Technologies**\\n - Authors: Zina Chkirbene, Ala Gouissem, Ridha Hamila, Devrim Unal\\n - Abstract: The advent of Reflecting Surfaces (IRS) and Unmanned Aerial Vehicles (UAVs) is setting a new benchmark in the field of wireless communications. IRS, with their groundbreaking ability to manipulate electromagnetic waves, have opened avenues for substantial enhancements in signal quality, network efficiency, and spectral usage.\\n - Date: 2 June, 2024\\n\\n7. **Scalable and RISC-V Programmable Near-Memory Computing Architectures for Edge Nodes**\\n - Authors: Michele Caon, Clément Choné, Pasquale Davide Schiavone, Alexandre Levisse, Guido Masera, Maurizio Martina, David Atienza\\n - Abstract: The widespread adoption of data-centric algorithms, particularly AI and ML, has exposed the limitations of centralized processing, driving the need for scalable and programmable near-memory computing architectures for edge nodes.\\n - Date: 20 June, 2024\\n\\n8. **Enhancing robustness of data-driven SHM models: adversarial training with circle loss**\\n - Authors: Xiangli Yang, Xijie Deng, Hanwei Zhang, Yang Zou, Jianxi Yang\\n - Abstract: Structural health monitoring (SHM) is critical to safeguarding the safety and reliability of aerospace, civil, and mechanical infrastructures. This paper discusses the use of adversarial training with circle loss to enhance the robustness of data-driven SHM models.\\n - Date: 20 June, 2024\\n\\n9. **Understanding Pedestrian Movement Using Urban Sensing Technologies: The Promise of Audio-based Sensors**\\n - Authors: Chaeyeon Han, Pavan Seshadri, Yiwei Ding, Noah Posner, Bon Woo Koo, Animesh Agrawal, Alexander Lerch, Subhrajit Guhathakurta\\n - Abstract: Understanding pedestrian volumes and flows is essential for designing safer and more attractive pedestrian infrastructures. This study discusses a new approach to scale up urban sensing of people with the help of novel audio-based technology.\\n - Date: 14 June, 2024\\n\\nASK_USER_HELP: Deshraj, I have found several papers that might be of interest to you. Would you like to proceed with any specific papers from the list above, or should I refine the search further?\\n\" status='NOT_SURE' url='https://arxiv.org/search/?query=Artificial+Intelligence+Machine+Learning+Infrastructure&searchtype=all&source=header' screenshot='' session_id='ff2ee9ef-60d4-4436-bc36-a81d94e0f410' metadata=Metadata(step_count=9, processing_time=66, temperature=0.2)\n"
- ]
- }
- ],
- "source": [
- "# Create prompt and browse arXiv\n",
- "prompt = f\"{command}\\n My past memories: {relevant_memories_text}\"\n",
- "browse_result = multion.browse(cmd=prompt, url=\"https://arxiv.org/\")\n",
- "print(browse_result)"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "y4bKPPa7DXNs"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install mem0ai multion"
+ ]
},
- "nbformat": 4,
- "nbformat_minor": 0
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pe4htqUmDdmS"
+ },
+ "source": [
+ "## Setup and Configuration\n",
+ "\n",
+ "First, we'll import the necessary libraries and set up our configurations.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "fsZwK7eLDh3I"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from mem0 import Memory\n",
+ "from multion.client import MultiOn\n",
+ "\n",
+ "# Configuration\n",
+ "OPENAI_API_KEY = \"sk-xxx\" # Replace with your actual OpenAI API key\n",
+ "MULTION_API_KEY = \"your-multion-key\" # Replace with your actual MultiOn API key\n",
+ "USER_ID = \"deshraj\"\n",
+ "\n",
+ "# Set up OpenAI API key\n",
+ "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
+ "\n",
+ "# Initialize Mem0 and MultiOn\n",
+ "memory = Memory()\n",
+ "multion = MultiOn(api_key=MULTION_API_KEY)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HTGVhGwaDl-1"
+ },
+ "source": [
+ "## Add memories to Mem0\n",
+ "\n",
+ "Next, we'll define our user data and add it to Mem0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xB3tm0_pDm6e",
+ "outputId": "aeab370c-8679-4d39-faaa-f702146d2fc4"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "User data added to memory.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define user data\n",
+ "USER_DATA = \"\"\"\n",
+ "About me\n",
+ "- I'm Deshraj Yadav, Co-founder and CTO at Mem0 (f.k.a Embedchain). I am broadly interested in the field of Artificial Intelligence and Machine Learning Infrastructure.\n",
+ "- Previously, I was Senior Autopilot Engineer at Tesla Autopilot where I led the Autopilot's AI Platform which helped the Tesla Autopilot team to track large scale training and model evaluation experiments, provide monitoring and observability into jobs and training cluster issues.\n",
+ "- I had built EvalAI as my masters thesis at Georgia Tech, which is an open-source platform for evaluating and comparing machine learning and artificial intelligence algorithms at scale.\n",
+ "- Outside of work, I am very much into cricket and play in two leagues (Cricbay and NACL) in San Francisco Bay Area.\n",
+ "\"\"\"\n",
+ "\n",
+ "# Add user data to memory\n",
+ "memory.add(USER_DATA, user_id=USER_ID)\n",
+ "print(\"User data added to memory.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZCPUJf0TDqUK"
+ },
+ "source": [
+ "## Retrieving Relevant Memories\n",
+ "\n",
+ "Now, we'll define our search command and retrieve relevant memories from Mem0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "s0PwAhNVDrIv",
+ "outputId": "59cbb767-b468-4139-8d0c-fa763918dbb0"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Relevant memories:\n",
+ "Name: Deshraj Yadav - Co-founder and CTO at Mem0 (formerly known as Embedchain) - Interested in Artificial Intelligence and Machine Learning Infrastructure - Previous role: Senior Autopilot Engineer at Tesla Autopilot - Led the Autopilot's AI Platform at Tesla, focusing on large scale training, model evaluation, monitoring, and observability - Built EvalAI as a master's thesis at Georgia Tech, an open-source platform for evaluating and comparing machine learning algorithms - Enjoys cricket - Plays in two cricket leagues: Cricbay and NACL in the San Francisco Bay Area\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define search command and retrieve relevant memories\n",
+ "command = \"Find papers on arxiv that I should read based on my interests.\"\n",
+ "\n",
+ "relevant_memories = memory.search(command, user_id=USER_ID, limit=3)\n",
+ "relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in relevant_memories)\n",
+ "print(f\"Relevant memories:\")\n",
+ "print(relevant_memories_text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jdge78_VDtgv"
+ },
+ "source": [
+ "## Browsing arXiv\n",
+ "\n",
+ "Finally, we'll use MultiOn to browse arXiv based on our command and relevant memories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4T_tLURTDvS-",
+ "outputId": "259ff32f-5d42-44e6-f2ef-c3557a8e9da6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "message=\"Summarizing the relevant papers found so far that align with Deshraj Yadav's interests in Artificial Intelligence and Machine Learning Infrastructure.\\n\\n1. **Urban Waterlogging Detection: A Challenging Benchmark and Large-Small Model Co-Adapter**\\n - Authors: Suqi Song, Chenxu Zhang, Peng Zhang, Pengkun Li, Fenglong Song, Lei Zhang\\n - Abstract: Urban waterlogging poses a major risk to public safety. Conventional methods using water-level sensors need high-maintenance to hardly achieve full coverage. Recent advances employ surveillance camera imagery and deep learning for detection, yet these struggle amidst scarce data and adverse environments.\\n - Date: 10 July, 2024\\n\\n2. **Intercepting Unauthorized Aerial Robots in Controlled Airspace Using Reinforcement Learning**\\n - Authors: Francisco Giral, Ignacio Gómez, Soledad Le Clainche\\n - Abstract: Ensuring the safe and efficient operation of airspace, particularly in urban environments and near critical infrastructure, necessitates effective methods to intercept unauthorized or non-cooperative UAVs. This work addresses the critical need for robust, adaptive systems capable of managing such scenarios.\\n - Date: 9 July, 2024\\n\\n3. **Efficient Materials Informatics between Rockets and Electrons**\\n - Authors: Adam M. Krajewski\\n - Abstract: This paper discusses the distinct efforts existing at three general scales of abstractions of what a material is - atomistic, physical, and design. At each, an efficient materials informatics is being built from the ground up based on the fundamental understanding of the underlying prior knowledge, including the data.\\n - Date: 5 July, 2024\\n\\n4. **ObfuscaTune: Obfuscated Offsite Fine-tuning and Inference of Proprietary LLMs on Private Datasets**\\n - Authors: Ahmed Frikha, Nassim Walha, Ricardo Mendes, Krishna Kanth Nakka, Xue Jiang, Xuebing Zhou\\n - Abstract: This paper proposes ObfuscaTune, a novel, efficient, and fully utility-preserving approach that combines a simple yet effective method to ensure the confidentiality of both the model and the data during offsite fine-tuning on a third-party cloud provider.\\n - Date: 3 July, 2024\\n\\n5. **MG-Verilog: Multi-grained Dataset Towards Enhanced LLM-assisted Verilog Generation**\\n - Authors: Yongan Zhang, Zhongzhi Yu, Yonggan Fu, Cheng Wan, Yingyan Celine Lin\\n - Abstract: This paper discusses the necessity of providing domain-specific data during inference, fine-tuning, or pre-training to effectively leverage LLMs in hardware design. Existing publicly available hardware datasets are often limited in size, complexity, or detail, which hinders the effectiveness of LLMs in this domain.\\n - Date: 1 July, 2024\\n\\n6. **The Future of Aerial Communications: A Survey of IRS-Enhanced UAV Communication Technologies**\\n - Authors: Zina Chkirbene, Ala Gouissem, Ridha Hamila, Devrim Unal\\n - Abstract: The advent of Reflecting Surfaces (IRS) and Unmanned Aerial Vehicles (UAVs) is setting a new benchmark in the field of wireless communications. IRS, with their groundbreaking ability to manipulate electromagnetic waves, have opened avenues for substantial enhancements in signal quality, network efficiency, and spectral usage.\\n - Date: 2 June, 2024\\n\\n7. **Scalable and RISC-V Programmable Near-Memory Computing Architectures for Edge Nodes**\\n - Authors: Michele Caon, Clément Choné, Pasquale Davide Schiavone, Alexandre Levisse, Guido Masera, Maurizio Martina, David Atienza\\n - Abstract: The widespread adoption of data-centric algorithms, particularly AI and ML, has exposed the limitations of centralized processing, driving the need for scalable and programmable near-memory computing architectures for edge nodes.\\n - Date: 20 June, 2024\\n\\n8. **Enhancing robustness of data-driven SHM models: adversarial training with circle loss**\\n - Authors: Xiangli Yang, Xijie Deng, Hanwei Zhang, Yang Zou, Jianxi Yang\\n - Abstract: Structural health monitoring (SHM) is critical to safeguarding the safety and reliability of aerospace, civil, and mechanical infrastructures. This paper discusses the use of adversarial training with circle loss to enhance the robustness of data-driven SHM models.\\n - Date: 20 June, 2024\\n\\n9. **Understanding Pedestrian Movement Using Urban Sensing Technologies: The Promise of Audio-based Sensors**\\n - Authors: Chaeyeon Han, Pavan Seshadri, Yiwei Ding, Noah Posner, Bon Woo Koo, Animesh Agrawal, Alexander Lerch, Subhrajit Guhathakurta\\n - Abstract: Understanding pedestrian volumes and flows is essential for designing safer and more attractive pedestrian infrastructures. This study discusses a new approach to scale up urban sensing of people with the help of novel audio-based technology.\\n - Date: 14 June, 2024\\n\\nASK_USER_HELP: Deshraj, I have found several papers that might be of interest to you. Would you like to proceed with any specific papers from the list above, or should I refine the search further?\\n\" status='NOT_SURE' url='https://arxiv.org/search/?query=Artificial+Intelligence+Machine+Learning+Infrastructure&searchtype=all&source=header' screenshot='' session_id='ff2ee9ef-60d4-4436-bc36-a81d94e0f410' metadata=Metadata(step_count=9, processing_time=66, temperature=0.2)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create prompt and browse arXiv\n",
+ "prompt = f\"{command}\\n My past memories: {relevant_memories_text}\"\n",
+ "browse_result = multion.browse(cmd=prompt, url=\"https://arxiv.org/\")\n",
+ "print(browse_result)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
diff --git a/cookbooks/multion_travel_agent.ipynb b/cookbooks/multion_travel_agent.ipynb
index 19633707..f9211da1 100644
--- a/cookbooks/multion_travel_agent.ipynb
+++ b/cookbooks/multion_travel_agent.ipynb
@@ -1,306 +1,296 @@
{
- "cells": [
- {
- "cell_type": "code",
- "source": [
- "!pip install mem0ai"
- ],
- "metadata": {
- "id": "fu3euPKZsbaC"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "U2VC_0FElQid"
- },
- "outputs": [],
- "source": [
- "import os\n",
- "from openai import OpenAI\n",
- "from mem0 import MemoryClient\n",
- "from multion.client import MultiOn\n",
- "\n",
- "# Configuration\n",
- "OPENAI_API_KEY = 'sk-xxx' # Replace with your actual OpenAI API key\n",
- "MULTION_API_KEY = 'xx' # Replace with your actual MultiOn API key\n",
- "MEM0_API_KEY = 'xx' # Replace with your actual Mem0 API key\n",
- "USER_ID = \"test_travel_agent\"\n",
- "\n",
- "# Set up OpenAI API key\n",
- "os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY\n",
- "\n",
- "# Initialize Mem0 and MultiOn\n",
- "memory = MemoryClient(api_key=MEM0_API_KEY)\n",
- "multion = MultiOn(api_key=MULTION_API_KEY)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sq-OdPHKlQie",
- "outputId": "1d605222-0bf5-4ac9-99b9-6059b502c20b"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'message': 'Memory added successfully!'}"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Add conversation to Mem0\n",
- "conversation = [\n",
- " {\n",
- " \"role\": \"user\",\n",
- " \"content\": \"What are the best travel destinations in the world?\"\n",
- " },\n",
- " {\n",
- " \"role\": \"assistant\",\n",
- " \"content\": \"Could you please specify your interests or the type of travel information you are looking for? This will help me find the most relevant information for you.\"\n",
- " },\n",
- " {\n",
- " \"role\": \"user\",\n",
- " \"content\": \"Sure, I want to travel to San Francisco.\"\n",
- " },\n",
- " {\n",
- " \"role\": \"assistant\",\n",
- " \"content\": \"\"\"\n",
- " Based on the information gathered from TripAdvisor, here are some popular attractions, activities, and travel tips for San Francisco:\n",
- "\n",
- " 1. **Golden Gate Bridge**: A must-see iconic landmark.\n",
- " 2. **Alcatraz Island**: Famous former prison offering tours.\n",
- " 3. **Fisherman's Wharf**: Popular tourist area with shops, restaurants, and sea lions.\n",
- " 4. **Chinatown**: The largest Chinatown outside of Asia.\n",
- " 5. **Golden Gate Park**: Large urban park with gardens, museums, and recreational activities.\n",
- " 6. **Cable Cars**: Historic streetcars offering a unique way to see the city.\n",
- " 7. **Exploratorium**: Interactive science museum.\n",
- " 8. **San Francisco Museum of Modern Art (SFMOMA)**: Modern and contemporary art museum.\n",
- " 9. **Lombard Street**: Known for its steep, one-block section with eight hairpin turns.\n",
- " 10. **Union Square**: Major shopping and cultural hub.\n",
- "\n",
- " Travel Tips:\n",
- " - **Weather**: San Francisco has a mild climate, but it can be foggy and windy. Dress in layers.\n",
- " - **Transportation**: Use public transportation like BART, Muni, and cable cars to get around.\n",
- " - **Safety**: Be aware of your surroundings, especially in crowded tourist areas.\n",
- " - **Dining**: Try local specialties like sourdough bread, seafood, and Mission-style burritos.\n",
- " \"\"\"\n",
- " },\n",
- " {\n",
- " \"role\": \"user\",\n",
- " \"content\": \"Show me hotels around Golden Gate Bridge.\"\n",
- " },\n",
- " {\n",
- " \"role\": \"assistant\",\n",
- " \"content\": \"\"\"\n",
- " The search results for hotels around Golden Gate Bridge in San Francisco include:\n",
- "\n",
- " 1. Hilton Hotels In San Francisco - Hotel Near Fishermans Wharf (hilton.com)\n",
- " 2. The 10 Closest Hotels to Golden Gate Bridge (tripadvisor.com)\n",
- " 3. Hotels near Golden Gate Bridge (expedia.com)\n",
- " 4. Hotels near Golden Gate Bridge (hotels.com)\n",
- " 5. Holiday Inn Express & Suites San Francisco Fishermans Wharf, an IHG Hotel $146 (1.8K) 3-star hotel Golden Gate Bridge • 3.5 mi DEAL 19% less than usual\n",
- " 6. Holiday Inn San Francisco-Golden Gateway, an IHG Hotel $151 (3.5K) 3-star hotel Golden Gate Bridge • 3.7 mi Casual hotel with dining, a bar & a pool\n",
- " 7. Hotel Zephyr San Francisco $159 (3.8K) 4-star hotel Golden Gate Bridge • 3.7 mi Nautical-themed lodging with bay views\n",
- " 8. Lodge at the Presidio\n",
- " 9. The Inn Above Tide\n",
- " 10. Cavallo Point\n",
- " 11. Casa Madrona Hotel and Spa\n",
- " 12. Cow Hollow Inn and Suites\n",
- " 13. Samesun San Francisco\n",
- " 14. Inn on Broadway\n",
- " 15. Coventry Motor Inn\n",
- " 16. HI San Francisco Fisherman's Wharf Hostel\n",
- " 17. Loews Regency San Francisco Hotel\n",
- " 18. Fairmont Heritage Place Ghirardelli Square\n",
- " 19. Hotel Drisco Pacific Heights\n",
- " 20. Travelodge by Wyndham Presidio San Francisco\n",
- " \"\"\"\n",
- " }\n",
- "]\n",
- "\n",
- "memory.add(conversation, user_id=USER_ID)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "hO8z9aNTlQif"
- },
- "outputs": [],
- "source": [
- "def get_travel_info(question, use_memory=True):\n",
- " \"\"\"\n",
- " Get travel information based on user's question and optionally their preferences from memory.\n",
- "\n",
- " \"\"\"\n",
- " if use_memory:\n",
- " previous_memories = memory.search(question, user_id=USER_ID)\n",
- " relevant_memories_text = \"\"\n",
- " if previous_memories:\n",
- " print(\"Using previous memories to enhance the search...\")\n",
- " relevant_memories_text = '\\n'.join(mem[\"memory\"] for mem in previous_memories)\n",
- "\n",
- " command = \"Find travel information based on my interests:\"\n",
- " prompt = f\"{command}\\n Question: {question} \\n My preferences: {relevant_memories_text}\"\n",
- " else:\n",
- " command = \"Find travel information based on my interests:\"\n",
- " prompt = f\"{command}\\n Question: {question}\"\n",
- "\n",
- "\n",
- " print(\"Searching for travel information...\")\n",
- " browse_result = multion.browse(cmd=prompt)\n",
- " return browse_result.message"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Wp2xpzMrlQig"
- },
- "source": [
- "## Example 1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "bPRPwqsplQig"
- },
- "outputs": [],
- "source": [
- "question = \"Show me flight details for it.\"\n",
- "answer_without_memory = get_travel_info(question, use_memory=False)\n",
- "answer_with_memory = get_travel_info(question, use_memory=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "a76ifa2HlQig"
- },
- "source": [
- "| Without Memory | With Memory |\n",
- "|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
- "| I have performed a Google search for \"flight details\" and reviewed the search results. Here are some relevant links and information: | Memorizing the following information: Flight details for San Francisco: |\n",
- "| 1. **FlightStats Global Flight Tracker** - Track the real-time flight status of your flight. See if your flight has been delayed or cancelled and track the live status.
[Flight Tracker - FlightStats](https://www.flightstats.com/flight-tracker/search) | 1. Prices from $232. Depart Thursday, August 22. Return Thursday, August 29.
2. Prices from $216. Depart Friday, August 23. Return Friday, August 30.
3. Prices from $236. Depart Saturday, August 24. Return Saturday, August 31.
4. Prices from $215. Depart Sunday, August 25. Return Sunday, September 1. |\n",
- "| 2. **FlightAware - Flight Tracker** - Track live flights worldwide, see flight cancellations, and browse by airport.
[FlightAware - Flight Tracker](https://www.flightaware.com) | 5. Prices from $218. Depart Monday, August 26. Return Monday, September 2.
6. Prices from $211. Depart Tuesday, August 27. Return Tuesday, September 3.
7. Prices from $198. Depart Wednesday, August 28. Return Wednesday, September 4.
8. Prices from $218. Depart Thursday, August 29. Return Thursday, September 5. |\n",
- "| 3. **Google Flights** - Show flights based on your search.
[Google Flights](https://www.google.com/flights) | 9. Prices from $194. Depart Friday, August 30. Return Friday, September 6.
10. Prices from $218. Depart Saturday, August 31. Return Saturday, September 7.
11. Prices from $212. Depart Sunday, September 1. Return Sunday, September 8.
12. Prices from $247. Depart Monday, September 2. Return Monday, September 9. |\n",
- "| | 13. Prices from $212. Depart Tuesday, September 3. Return Tuesday, September 10.
14. Prices from $203. Depart Wednesday, September 4. Return Wednesday, September 11.
15. Prices from $242. Depart Thursday, September 5. Return Thursday, September 12.
16. Prices from $191. Depart Friday, September 6. Return Friday, September 13. |\n",
- "| | 17. Prices from $215. Depart Saturday, September 7. Return Saturday, September 14.
18. Prices from $229. Depart Sunday, September 8. Return Sunday, September 15.
19. Prices from $183. Depart Monday, September 9. Return Monday, September 16.
65. Prices from $194. Depart Friday, October 25. Return Friday, November 1. |\n",
- "| | 66. Prices from $205. Depart Saturday, October 26. Return Saturday, November 2.
67. Prices from $241. Depart Sunday, October 27. Return Sunday, November 3. |\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0cXpiAwMlQig"
- },
- "source": [
- "## Example 2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "LpprKfpslQih"
- },
- "outputs": [],
- "source": [
- "question = \"What places to visit there?\"\n",
- "answer_without_memory = get_travel_info(question, use_memory=False)\n",
- "answer_with_memory = get_travel_info(question, use_memory=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "kpfjeY1_lQih"
- },
- "source": [
- "| Without Memory | With Memory |\n",
- "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
- "| Based on the information gathered, here are some top travel destinations to consider visiting: | Based on the information gathered, here are some top places to visit in San Francisco: |\n",
- "| 1. **Paris**: Known for its iconic attractions like the Eiffel Tower and the Louvre, Paris offers quaint cafes, trendy shopping districts, and beautiful Haussmann architecture. It's a city where you can always discover something new with each visit. | 1. **Golden Gate Bridge** - An iconic symbol of San Francisco, perfect for walking, biking, or simply enjoying the view.
2. **Alcatraz Island** - The historic former prison offers tours and insights into its storied past.
3. **Fisherman's Wharf** - A bustling waterfront area known for its seafood, shopping, and attractions like Pier 39.
4. **Golden Gate Park** - A large urban park with gardens, museums, and recreational activities.
5. **Chinatown San Francisco** - One of the oldest and most famous Chinatowns in North America, offering unique shops and delicious food.
6. **Coit Tower** - Offers panoramic views of the city and murals depicting San Francisco's history.
7. **Lands End** - A beautiful coastal trail with stunning views of the Pacific Ocean and the Golden Gate Bridge.
8. **Palace of Fine Arts** - A picturesque structure and park, perfect for a leisurely stroll or photo opportunities.
9. **Crissy Field & The Presidio Tunnel Tops** - Great for outdoor activities and scenic views of the bay. |\n",
- "| 2. **Bora Bora**: This small island in French Polynesia is famous for its stunning turquoise waters, luxurious overwater bungalows, and vibrant coral reefs. It's a popular destination for honeymooners and those seeking a tropical paradise. | |\n",
- "| 3. **Glacier National Park**: Located in Montana, USA, this park is known for its breathtaking landscapes, including rugged mountains, pristine lakes, and diverse wildlife. It's a haven for outdoor enthusiasts and hikers. | |\n",
- "| 4. **Rome**: The capital of Italy, Rome is rich in history and culture, featuring landmarks such as the Colosseum, the Vatican, and the Pantheon. It's a city where ancient history meets modern life. | |\n",
- "| 5. **Swiss Alps**: Renowned for their stunning natural beauty, the Swiss Alps offer opportunities for skiing, hiking, and enjoying picturesque mountain villages. | |\n",
- "| 6. **Maui**: One of Hawaii's most popular islands, Maui is known for its beautiful beaches, lush rainforests, and the scenic Hana Highway. It's a great destination for both relaxation and adventure. | |\n",
- "| 7. **London, England**: A vibrant city with a mix of historical landmarks like the Tower of London and modern attractions such as the London Eye. London offers diverse cultural experiences, world-class museums, and a bustling nightlife. | |\n",
- "| 8. **Maldives**: This tropical paradise in the Indian Ocean is famous for its crystal-clear waters, luxurious resorts, and abundant marine life. It's an ideal destination for snorkeling, diving, and relaxation. | |\n",
- "| 9. **Turks & Caicos**: Known for its pristine beaches and turquoise waters, this Caribbean destination is perfect for water sports, beach lounging, and exploring coral reefs. | |\n",
- "| 10. **Tokyo**: Japan's bustling capital offers a unique blend of traditional and modern attractions, from ancient temples to futuristic skyscrapers. Tokyo is also known for its vibrant food scene and shopping districts. | |\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XdpkcMrclQih"
- },
- "source": [
- "## Example 3"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Nntl2FxulQih"
- },
- "outputs": [],
- "source": [
- "question = \"What the weather there?\"\n",
- "answer_without_memory = get_travel_info(question, use_memory=False)\n",
- "answer_with_memory = get_travel_info(question, use_memory=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yt2pj1irlQih"
- },
- "source": [
- "| Without Memory | With Memory |\n",
- "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
- "| The current weather in Paris is light rain with a temperature of 67°F. The precipitation is at 50%, humidity is 95%, and the wind speed is 5 mph. | The current weather in San Francisco is as follows:
- **Temperature**: 59°F
- **Condition**: Clear with periodic clouds
- **Precipitation**: 3%
- **Humidity**: 87%
- **Wind**: 12 mph |\n"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.3"
- },
- "colab": {
- "provenance": []
- }
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fu3euPKZsbaC"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install mem0ai"
+ ]
},
- "nbformat": 4,
- "nbformat_minor": 0
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "U2VC_0FElQid"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from openai import OpenAI\n",
+ "from mem0 import MemoryClient\n",
+ "from multion.client import MultiOn\n",
+ "\n",
+ "# Configuration\n",
+ "OPENAI_API_KEY = \"sk-xxx\" # Replace with your actual OpenAI API key\n",
+ "MULTION_API_KEY = \"xx\" # Replace with your actual MultiOn API key\n",
+ "MEM0_API_KEY = \"xx\" # Replace with your actual Mem0 API key\n",
+ "USER_ID = \"test_travel_agent\"\n",
+ "\n",
+ "# Set up OpenAI API key\n",
+ "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
+ "\n",
+ "# Initialize Mem0 and MultiOn\n",
+ "memory = MemoryClient(api_key=MEM0_API_KEY)\n",
+ "multion = MultiOn(api_key=MULTION_API_KEY)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sq-OdPHKlQie",
+ "outputId": "1d605222-0bf5-4ac9-99b9-6059b502c20b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'message': 'Memory added successfully!'}"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Add conversation to Mem0\n",
+ "conversation = [\n",
+ " {\"role\": \"user\", \"content\": \"What are the best travel destinations in the world?\"},\n",
+ " {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"Could you please specify your interests or the type of travel information you are looking for? This will help me find the most relevant information for you.\",\n",
+ " },\n",
+ " {\"role\": \"user\", \"content\": \"Sure, I want to travel to San Francisco.\"},\n",
+ " {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"\"\"\n",
+ " Based on the information gathered from TripAdvisor, here are some popular attractions, activities, and travel tips for San Francisco:\n",
+ "\n",
+ " 1. **Golden Gate Bridge**: A must-see iconic landmark.\n",
+ " 2. **Alcatraz Island**: Famous former prison offering tours.\n",
+ " 3. **Fisherman's Wharf**: Popular tourist area with shops, restaurants, and sea lions.\n",
+ " 4. **Chinatown**: The largest Chinatown outside of Asia.\n",
+ " 5. **Golden Gate Park**: Large urban park with gardens, museums, and recreational activities.\n",
+ " 6. **Cable Cars**: Historic streetcars offering a unique way to see the city.\n",
+ " 7. **Exploratorium**: Interactive science museum.\n",
+ " 8. **San Francisco Museum of Modern Art (SFMOMA)**: Modern and contemporary art museum.\n",
+ " 9. **Lombard Street**: Known for its steep, one-block section with eight hairpin turns.\n",
+ " 10. **Union Square**: Major shopping and cultural hub.\n",
+ "\n",
+ " Travel Tips:\n",
+ " - **Weather**: San Francisco has a mild climate, but it can be foggy and windy. Dress in layers.\n",
+ " - **Transportation**: Use public transportation like BART, Muni, and cable cars to get around.\n",
+ " - **Safety**: Be aware of your surroundings, especially in crowded tourist areas.\n",
+ " - **Dining**: Try local specialties like sourdough bread, seafood, and Mission-style burritos.\n",
+ " \"\"\",\n",
+ " },\n",
+ " {\"role\": \"user\", \"content\": \"Show me hotels around Golden Gate Bridge.\"},\n",
+ " {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"\"\"\n",
+ " The search results for hotels around Golden Gate Bridge in San Francisco include:\n",
+ "\n",
+ " 1. Hilton Hotels In San Francisco - Hotel Near Fishermans Wharf (hilton.com)\n",
+ " 2. The 10 Closest Hotels to Golden Gate Bridge (tripadvisor.com)\n",
+ " 3. Hotels near Golden Gate Bridge (expedia.com)\n",
+ " 4. Hotels near Golden Gate Bridge (hotels.com)\n",
+ " 5. Holiday Inn Express & Suites San Francisco Fishermans Wharf, an IHG Hotel $146 (1.8K) 3-star hotel Golden Gate Bridge • 3.5 mi DEAL 19% less than usual\n",
+ " 6. Holiday Inn San Francisco-Golden Gateway, an IHG Hotel $151 (3.5K) 3-star hotel Golden Gate Bridge • 3.7 mi Casual hotel with dining, a bar & a pool\n",
+ " 7. Hotel Zephyr San Francisco $159 (3.8K) 4-star hotel Golden Gate Bridge • 3.7 mi Nautical-themed lodging with bay views\n",
+ " 8. Lodge at the Presidio\n",
+ " 9. The Inn Above Tide\n",
+ " 10. Cavallo Point\n",
+ " 11. Casa Madrona Hotel and Spa\n",
+ " 12. Cow Hollow Inn and Suites\n",
+ " 13. Samesun San Francisco\n",
+ " 14. Inn on Broadway\n",
+ " 15. Coventry Motor Inn\n",
+ " 16. HI San Francisco Fisherman's Wharf Hostel\n",
+ " 17. Loews Regency San Francisco Hotel\n",
+ " 18. Fairmont Heritage Place Ghirardelli Square\n",
+ " 19. Hotel Drisco Pacific Heights\n",
+ " 20. Travelodge by Wyndham Presidio San Francisco\n",
+ " \"\"\",\n",
+ " },\n",
+ "]\n",
+ "\n",
+ "memory.add(conversation, user_id=USER_ID)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hO8z9aNTlQif"
+ },
+ "outputs": [],
+ "source": [
+ "def get_travel_info(question, use_memory=True):\n",
+ " \"\"\"\n",
+ " Get travel information based on user's question and optionally their preferences from memory.\n",
+ "\n",
+ " \"\"\"\n",
+ " if use_memory:\n",
+ " previous_memories = memory.search(question, user_id=USER_ID)\n",
+ " relevant_memories_text = \"\"\n",
+ " if previous_memories:\n",
+ " print(\"Using previous memories to enhance the search...\")\n",
+ " relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in previous_memories)\n",
+ "\n",
+ " command = \"Find travel information based on my interests:\"\n",
+ " prompt = f\"{command}\\n Question: {question} \\n My preferences: {relevant_memories_text}\"\n",
+ " else:\n",
+ " command = \"Find travel information based on my interests:\"\n",
+ " prompt = f\"{command}\\n Question: {question}\"\n",
+ "\n",
+ " print(\"Searching for travel information...\")\n",
+ " browse_result = multion.browse(cmd=prompt)\n",
+ " return browse_result.message"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Wp2xpzMrlQig"
+ },
+ "source": [
+ "## Example 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bPRPwqsplQig"
+ },
+ "outputs": [],
+ "source": [
+ "question = \"Show me flight details for it.\"\n",
+ "answer_without_memory = get_travel_info(question, use_memory=False)\n",
+ "answer_with_memory = get_travel_info(question, use_memory=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a76ifa2HlQig"
+ },
+ "source": [
+ "| Without Memory | With Memory |\n",
+ "|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
+ "| I have performed a Google search for \"flight details\" and reviewed the search results. Here are some relevant links and information: | Memorizing the following information: Flight details for San Francisco: |\n",
+ "| 1. **FlightStats Global Flight Tracker** - Track the real-time flight status of your flight. See if your flight has been delayed or cancelled and track the live status.
[Flight Tracker - FlightStats](https://www.flightstats.com/flight-tracker/search) | 1. Prices from $232. Depart Thursday, August 22. Return Thursday, August 29.
2. Prices from $216. Depart Friday, August 23. Return Friday, August 30.
3. Prices from $236. Depart Saturday, August 24. Return Saturday, August 31.
4. Prices from $215. Depart Sunday, August 25. Return Sunday, September 1. |\n",
+ "| 2. **FlightAware - Flight Tracker** - Track live flights worldwide, see flight cancellations, and browse by airport.
[FlightAware - Flight Tracker](https://www.flightaware.com) | 5. Prices from $218. Depart Monday, August 26. Return Monday, September 2.
6. Prices from $211. Depart Tuesday, August 27. Return Tuesday, September 3.
7. Prices from $198. Depart Wednesday, August 28. Return Wednesday, September 4.
8. Prices from $218. Depart Thursday, August 29. Return Thursday, September 5. |\n",
+ "| 3. **Google Flights** - Show flights based on your search.
[Google Flights](https://www.google.com/flights) | 9. Prices from $194. Depart Friday, August 30. Return Friday, September 6.
10. Prices from $218. Depart Saturday, August 31. Return Saturday, September 7.
11. Prices from $212. Depart Sunday, September 1. Return Sunday, September 8.
12. Prices from $247. Depart Monday, September 2. Return Monday, September 9. |\n",
+ "| | 13. Prices from $212. Depart Tuesday, September 3. Return Tuesday, September 10.
14. Prices from $203. Depart Wednesday, September 4. Return Wednesday, September 11.
15. Prices from $242. Depart Thursday, September 5. Return Thursday, September 12.
16. Prices from $191. Depart Friday, September 6. Return Friday, September 13. |\n",
+ "| | 17. Prices from $215. Depart Saturday, September 7. Return Saturday, September 14.
18. Prices from $229. Depart Sunday, September 8. Return Sunday, September 15.
19. Prices from $183. Depart Monday, September 9. Return Monday, September 16.
65. Prices from $194. Depart Friday, October 25. Return Friday, November 1. |\n",
+ "| | 66. Prices from $205. Depart Saturday, October 26. Return Saturday, November 2.
67. Prices from $241. Depart Sunday, October 27. Return Sunday, November 3. |\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0cXpiAwMlQig"
+ },
+ "source": [
+ "## Example 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LpprKfpslQih"
+ },
+ "outputs": [],
+ "source": [
+ "question = \"What places to visit there?\"\n",
+ "answer_without_memory = get_travel_info(question, use_memory=False)\n",
+ "answer_with_memory = get_travel_info(question, use_memory=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kpfjeY1_lQih"
+ },
+ "source": [
+ "| Without Memory | With Memory |\n",
+ "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
+ "| Based on the information gathered, here are some top travel destinations to consider visiting: | Based on the information gathered, here are some top places to visit in San Francisco: |\n",
+ "| 1. **Paris**: Known for its iconic attractions like the Eiffel Tower and the Louvre, Paris offers quaint cafes, trendy shopping districts, and beautiful Haussmann architecture. It's a city where you can always discover something new with each visit. | 1. **Golden Gate Bridge** - An iconic symbol of San Francisco, perfect for walking, biking, or simply enjoying the view.
2. **Alcatraz Island** - The historic former prison offers tours and insights into its storied past.
3. **Fisherman's Wharf** - A bustling waterfront area known for its seafood, shopping, and attractions like Pier 39.
4. **Golden Gate Park** - A large urban park with gardens, museums, and recreational activities.
5. **Chinatown San Francisco** - One of the oldest and most famous Chinatowns in North America, offering unique shops and delicious food.
6. **Coit Tower** - Offers panoramic views of the city and murals depicting San Francisco's history.
7. **Lands End** - A beautiful coastal trail with stunning views of the Pacific Ocean and the Golden Gate Bridge.
8. **Palace of Fine Arts** - A picturesque structure and park, perfect for a leisurely stroll or photo opportunities.
9. **Crissy Field & The Presidio Tunnel Tops** - Great for outdoor activities and scenic views of the bay. |\n",
+ "| 2. **Bora Bora**: This small island in French Polynesia is famous for its stunning turquoise waters, luxurious overwater bungalows, and vibrant coral reefs. It's a popular destination for honeymooners and those seeking a tropical paradise. | |\n",
+ "| 3. **Glacier National Park**: Located in Montana, USA, this park is known for its breathtaking landscapes, including rugged mountains, pristine lakes, and diverse wildlife. It's a haven for outdoor enthusiasts and hikers. | |\n",
+ "| 4. **Rome**: The capital of Italy, Rome is rich in history and culture, featuring landmarks such as the Colosseum, the Vatican, and the Pantheon. It's a city where ancient history meets modern life. | |\n",
+ "| 5. **Swiss Alps**: Renowned for their stunning natural beauty, the Swiss Alps offer opportunities for skiing, hiking, and enjoying picturesque mountain villages. | |\n",
+ "| 6. **Maui**: One of Hawaii's most popular islands, Maui is known for its beautiful beaches, lush rainforests, and the scenic Hana Highway. It's a great destination for both relaxation and adventure. | |\n",
+ "| 7. **London, England**: A vibrant city with a mix of historical landmarks like the Tower of London and modern attractions such as the London Eye. London offers diverse cultural experiences, world-class museums, and a bustling nightlife. | |\n",
+ "| 8. **Maldives**: This tropical paradise in the Indian Ocean is famous for its crystal-clear waters, luxurious resorts, and abundant marine life. It's an ideal destination for snorkeling, diving, and relaxation. | |\n",
+ "| 9. **Turks & Caicos**: Known for its pristine beaches and turquoise waters, this Caribbean destination is perfect for water sports, beach lounging, and exploring coral reefs. | |\n",
+ "| 10. **Tokyo**: Japan's bustling capital offers a unique blend of traditional and modern attractions, from ancient temples to futuristic skyscrapers. Tokyo is also known for its vibrant food scene and shopping districts. | |\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XdpkcMrclQih"
+ },
+ "source": [
+ "## Example 3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Nntl2FxulQih"
+ },
+ "outputs": [],
+ "source": [
+ "question = \"What the weather there?\"\n",
+ "answer_without_memory = get_travel_info(question, use_memory=False)\n",
+ "answer_with_memory = get_travel_info(question, use_memory=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yt2pj1irlQih"
+ },
+ "source": [
+ "| Without Memory | With Memory |\n",
+ "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
+ "| The current weather in Paris is light rain with a temperature of 67°F. The precipitation is at 50%, humidity is 95%, and the wind speed is 5 mph. | The current weather in San Francisco is as follows:
- **Temperature**: 59°F
- **Condition**: Clear with periodic clouds
- **Precipitation**: 3%
- **Humidity**: 87%
- **Wind**: 12 mph |\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
\ No newline at end of file
diff --git a/mem0/client/main.py b/mem0/client/main.py
index ae0519b9..e94c6f83 100644
--- a/mem0/client/main.py
+++ b/mem0/client/main.py
@@ -10,7 +10,11 @@ from mem0.memory.setup import setup_config
from mem0.memory.telemetry import capture_client_event
logger = logging.getLogger(__name__)
-warnings.filterwarnings('always', category=DeprecationWarning, message="The 'session_id' parameter is deprecated. User 'run_id' instead.")
+warnings.filterwarnings(
+ "always",
+ category=DeprecationWarning,
+ message="The 'session_id' parameter is deprecated. User 'run_id' instead.",
+)
# Setup user config
setup_config()
@@ -82,14 +86,10 @@ class MemoryClient:
response = self.client.get("/v1/memories/", params={"user_id": "test"})
response.raise_for_status()
except httpx.HTTPStatusError:
- raise ValueError(
- "Invalid API Key. Please get a valid API Key from https://app.mem0.ai"
- )
+ raise ValueError("Invalid API Key. Please get a valid API Key from https://app.mem0.ai")
@api_error_handler
- def add(
- self, messages: Union[str, List[Dict[str, str]]], **kwargs
- ) -> Dict[str, Any]:
+ def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
"""Add a new memory.
Args:
@@ -253,9 +253,7 @@ class MemoryClient:
"""Delete all users, agents, or sessions."""
entities = self.users()
for entity in entities["results"]:
- response = self.client.delete(
- f"/v1/entities/{entity['type']}/{entity['id']}/"
- )
+ response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/")
response.raise_for_status()
capture_client_event("client.delete_users", self)
@@ -312,7 +310,7 @@ class MemoryClient:
"The 'session_id' parameter is deprecated and will be removed in version 0.1.20. "
"Use 'run_id' instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
kwargs["run_id"] = kwargs.pop("session_id")
@@ -335,7 +333,7 @@ class MemoryClient:
"The 'session_id' parameter is deprecated and will be removed in version 0.1.20. "
"Use 'run_id' instead.",
DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
kwargs["run_id"] = kwargs.pop("session_id")
diff --git a/mem0/configs/base.py b/mem0/configs/base.py
index a83e00d6..55e09f27 100644
--- a/mem0/configs/base.py
+++ b/mem0/configs/base.py
@@ -17,18 +17,10 @@ class MemoryItem(BaseModel):
) # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it
- metadata: Optional[Dict[str, Any]] = Field(
- None, description="Additional metadata for the text data"
- )
- score: Optional[float] = Field(
- None, description="The score associated with the text data"
- )
- created_at: Optional[str] = Field(
- None, description="The timestamp when the memory was created"
- )
- updated_at: Optional[str] = Field(
- None, description="The timestamp when the memory was updated"
- )
+ metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
+ score: Optional[float] = Field(None, description="The score associated with the text data")
+ created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
+ updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
class MemoryConfig(BaseModel):
@@ -60,7 +52,7 @@ class MemoryConfig(BaseModel):
description="Custom prompt for the memory",
default=None,
)
-
+
class AzureConfig(BaseModel):
"""
@@ -73,7 +65,10 @@ class AzureConfig(BaseModel):
api_version (str): The version of the Azure API being used.
"""
- api_key: str = Field(description="The API key used for authenticating with the Azure service.", default=None)
- azure_deployment : str = Field(description="The name of the Azure deployment.", default=None)
- azure_endpoint : str = Field(description="The endpoint URL for the Azure service.", default=None)
- api_version : str = Field(description="The version of the Azure API being used.", default=None)
+ api_key: str = Field(
+ description="The API key used for authenticating with the Azure service.",
+ default=None,
+ )
+ azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
+ azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
+ api_version: str = Field(description="The version of the Azure API being used.", default=None)
diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py
index f4659dce..63245872 100644
--- a/mem0/configs/embeddings/base.py
+++ b/mem0/configs/embeddings/base.py
@@ -60,6 +60,6 @@ class BaseEmbedderConfig(ABC):
# Huggingface specific
self.model_kwargs = model_kwargs or {}
-
+
# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
diff --git a/mem0/configs/prompts.py b/mem0/configs/prompts.py
index be00856b..d9192129 100644
--- a/mem0/configs/prompts.py
+++ b/mem0/configs/prompts.py
@@ -59,6 +59,7 @@ You should detect the language of the user input and record the facts in the sam
If you do not find anything relevant facts, user memories, and preferences in the below conversation, you can return an empty list corresponding to the "facts" key.
"""
+
def get_update_memory_messages(retrieved_old_memory_dict, response_content):
return f"""You are a smart memory manager which controls the memory of a system.
You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change.
diff --git a/mem0/configs/vector_stores/chroma.py b/mem0/configs/vector_stores/chroma.py
index 4a5ecbf5..1b0ef4e3 100644
--- a/mem0/configs/vector_stores/chroma.py
+++ b/mem0/configs/vector_stores/chroma.py
@@ -13,9 +13,7 @@ class ChromaDbConfig(BaseModel):
Client: ClassVar[type] = Client
collection_name: str = Field("mem0", description="Default name for the collection")
- client: Optional[Client] = Field(
- None, description="Existing ChromaDB client instance"
- )
+ client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[int] = Field(None, description="Database connection remote port")
diff --git a/mem0/configs/vector_stores/milvus.py b/mem0/configs/vector_stores/milvus.py
index 1e433df1..7578c6fc 100644
--- a/mem0/configs/vector_stores/milvus.py
+++ b/mem0/configs/vector_stores/milvus.py
@@ -1,22 +1,24 @@
from enum import Enum
-from typing import Dict, Any
-from pydantic import BaseModel, model_validator, Field
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field, model_validator
class MetricType(str, Enum):
"""
Metric Constant for milvus/ zilliz server.
"""
+
def __str__(self) -> str:
return str(self.value)
-
+
L2 = "L2"
- IP = "IP"
- COSINE = "COSINE"
- HAMMING = "HAMMING"
- JACCARD = "JACCARD"
-
-
+ IP = "IP"
+ COSINE = "COSINE"
+ HAMMING = "HAMMING"
+ JACCARD = "JACCARD"
+
+
class MilvusDBConfig(BaseModel):
url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
token: str = Field(None, description="Token for Zilliz server / local setup defaults to None.")
@@ -38,4 +40,4 @@ class MilvusDBConfig(BaseModel):
model_config = {
"arbitrary_types_allowed": True,
- }
\ No newline at end of file
+ }
diff --git a/mem0/configs/vector_stores/pgvector.py b/mem0/configs/vector_stores/pgvector.py
index df8dabf4..b81ed985 100644
--- a/mem0/configs/vector_stores/pgvector.py
+++ b/mem0/configs/vector_stores/pgvector.py
@@ -4,12 +4,9 @@ from pydantic import BaseModel, Field, model_validator
class PGVectorConfig(BaseModel):
-
dbname: str = Field("postgres", description="Default name for the database")
collection_name: str = Field("mem0", description="Default name for the collection")
- embedding_model_dims: Optional[int] = Field(
- 1536, description="Dimensions of the embedding model"
- )
+ embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
user: Optional[str] = Field(None, description="Database user")
password: Optional[str] = Field(None, description="Database password")
host: Optional[str] = Field(None, description="Database host. Default is localhost")
diff --git a/mem0/configs/vector_stores/qdrant.py b/mem0/configs/vector_stores/qdrant.py
index 10951db8..f8628d33 100644
--- a/mem0/configs/vector_stores/qdrant.py
+++ b/mem0/configs/vector_stores/qdrant.py
@@ -9,17 +9,11 @@ class QdrantConfig(BaseModel):
QdrantClient: ClassVar[type] = QdrantClient
collection_name: str = Field("mem0", description="Name of the collection")
- embedding_model_dims: Optional[int] = Field(
- 1536, description="Dimensions of the embedding model"
- )
- client: Optional[QdrantClient] = Field(
- None, description="Existing Qdrant client instance"
- )
+ embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+ client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server")
- path: Optional[str] = Field(
- "/tmp/qdrant", description="Path for local Qdrant database"
- )
+ path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
@@ -35,9 +29,7 @@ class QdrantConfig(BaseModel):
values.get("api_key"),
)
if not path and not (host and port) and not (url and api_key):
- raise ValueError(
- "Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
- )
+ raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
return values
@model_validator(mode="before")
diff --git a/mem0/embeddings/azure_openai.py b/mem0/embeddings/azure_openai.py
index 8e801ccd..d25cc00e 100644
--- a/mem0/embeddings/azure_openai.py
+++ b/mem0/embeddings/azure_openai.py
@@ -15,14 +15,14 @@ class AzureOpenAIEmbedding(EmbeddingBase):
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
-
+
self.client = AzureOpenAI(
- azure_deployment=azure_deployment,
+ azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
- http_client=self.config.http_client
- )
+ http_client=self.config.http_client,
+ )
def embed(self, text):
"""
@@ -35,8 +35,4 @@ class AzureOpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
- return (
- self.client.embeddings.create(input=[text], model=self.config.model)
- .data[0]
- .embedding
- )
+ return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
diff --git a/mem0/embeddings/configs.py b/mem0/embeddings/configs.py
index 9e3848cf..21349344 100644
--- a/mem0/embeddings/configs.py
+++ b/mem0/embeddings/configs.py
@@ -8,9 +8,7 @@ class EmbedderConfig(BaseModel):
description="Provider of the embedding model (e.g., 'ollama', 'openai')",
default="openai",
)
- config: Optional[dict] = Field(
- description="Configuration for the specific embedding model", default={}
- )
+ config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})
@field_validator("config")
def validate_config(cls, v, values):
diff --git a/mem0/embeddings/ollama.py b/mem0/embeddings/ollama.py
index 2e7f3758..ae00368e 100644
--- a/mem0/embeddings/ollama.py
+++ b/mem0/embeddings/ollama.py
@@ -9,7 +9,7 @@ try:
from ollama import Client
except ImportError:
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
- if user_input.lower() == 'y':
+ if user_input.lower() == "y":
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
from ollama import Client
diff --git a/mem0/embeddings/openai.py b/mem0/embeddings/openai.py
index be9195bf..b68b8ffc 100644
--- a/mem0/embeddings/openai.py
+++ b/mem0/embeddings/openai.py
@@ -29,8 +29,4 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
- return (
- self.client.embeddings.create(input=[text], model=self.config.model)
- .data[0]
- .embedding
- )
+ return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
diff --git a/mem0/embeddings/vertexai.py b/mem0/embeddings/vertexai.py
index 4839a2f3..bcdaaab2 100644
--- a/mem0/embeddings/vertexai.py
+++ b/mem0/embeddings/vertexai.py
@@ -6,6 +6,7 @@ from vertexai.language_models import TextEmbeddingModel
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
+
class VertexAI(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
@@ -34,6 +35,6 @@ class VertexAI(EmbeddingBase):
Returns:
list: The embedding vector.
"""
- embeddings = self.model.get_embeddings(texts=[text], output_dimensionality= self.config.embedding_dims)
-
+ embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims)
+
return embeddings[0].values
diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py
index 033637c3..c14249ad 100644
--- a/mem0/graphs/configs.py
+++ b/mem0/graphs/configs.py
@@ -18,28 +18,16 @@ class Neo4jConfig(BaseModel):
values.get("password"),
)
if not url or not username or not password:
- raise ValueError(
- "Please provide 'url', 'username' and 'password'."
- )
+ raise ValueError("Please provide 'url', 'username' and 'password'.")
return values
class GraphStoreConfig(BaseModel):
- provider: str = Field(
- description="Provider of the data store (e.g., 'neo4j')",
- default="neo4j"
- )
- config: Neo4jConfig = Field(
- description="Configuration for the specific data store",
- default=None
- )
- llm: Optional[LlmConfig] = Field(
- description="LLM configuration for querying the graph store",
- default=None
- )
+ provider: str = Field(description="Provider of the data store (e.g., 'neo4j')", default="neo4j")
+ config: Neo4jConfig = Field(description="Configuration for the specific data store", default=None)
+ llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
custom_prompt: Optional[str] = Field(
- description="Custom prompt to fetch entities from the given text",
- default=None
+ description="Custom prompt to fetch entities from the given text", default=None
)
@field_validator("config")
@@ -49,4 +37,3 @@ class GraphStoreConfig(BaseModel):
return Neo4jConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported graph store provider: {provider}")
-
diff --git a/mem0/graphs/tools.py b/mem0/graphs/tools.py
index d7279242..1fdbe91f 100644
--- a/mem0/graphs/tools.py
+++ b/mem0/graphs/tools.py
@@ -1,4 +1,3 @@
-
UPDATE_MEMORY_TOOL_GRAPH = {
"type": "function",
"function": {
@@ -9,21 +8,21 @@ UPDATE_MEMORY_TOOL_GRAPH = {
"properties": {
"source": {
"type": "string",
- "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph."
+ "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
},
"destination": {
"type": "string",
- "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph."
+ "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
},
"relationship": {
"type": "string",
- "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
- }
+ "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
+ },
},
"required": ["source", "destination", "relationship"],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
ADD_MEMORY_TOOL_GRAPH = {
@@ -36,29 +35,35 @@ ADD_MEMORY_TOOL_GRAPH = {
"properties": {
"source": {
"type": "string",
- "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created."
+ "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
},
"destination": {
"type": "string",
- "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created."
+ "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
},
"relationship": {
"type": "string",
- "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
+ "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
},
"source_type": {
"type": "string",
- "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph."
+ "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
},
"destination_type": {
"type": "string",
- "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph."
- }
+ "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
+ },
},
- "required": ["source", "destination", "relationship", "source_type", "destination_type"],
- "additionalProperties": False
- }
- }
+ "required": [
+ "source",
+ "destination",
+ "relationship",
+ "source_type",
+ "destination_type",
+ ],
+ "additionalProperties": False,
+ },
+ },
}
@@ -71,9 +76,9 @@ NOOP_TOOL = {
"type": "object",
"properties": {},
"required": [],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
@@ -94,17 +99,23 @@ ADD_MESSAGE_TOOL = {
"source_type": {"type": "string"},
"relation": {"type": "string"},
"destination_node": {"type": "string"},
- "destination_type": {"type": "string"}
+ "destination_type": {"type": "string"},
},
- "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
- "additionalProperties": False
- }
+ "required": [
+ "source_node",
+ "source_type",
+ "relation",
+ "destination_node",
+ "destination_type",
+ ],
+ "additionalProperties": False,
+ },
}
},
"required": ["entities"],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
@@ -118,23 +129,19 @@ SEARCH_TOOL = {
"properties": {
"nodes": {
"type": "array",
- "items": {
- "type": "string"
- },
- "description": "List of nodes to search for."
+ "items": {"type": "string"},
+ "description": "List of nodes to search for.",
},
"relations": {
"type": "array",
- "items": {
- "type": "string"
- },
- "description": "List of relations to search for."
- }
+ "items": {"type": "string"},
+ "description": "List of relations to search for.",
+ },
},
"required": ["nodes", "relations"],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
@@ -148,21 +155,21 @@ UPDATE_MEMORY_STRUCT_TOOL_GRAPH = {
"properties": {
"source": {
"type": "string",
- "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph."
+ "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.",
},
"destination": {
"type": "string",
- "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph."
+ "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.",
},
"relationship": {
"type": "string",
- "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
- }
+ "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
+ },
},
"required": ["source", "destination", "relationship"],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
ADD_MEMORY_STRUCT_TOOL_GRAPH = {
@@ -176,29 +183,35 @@ ADD_MEMORY_STRUCT_TOOL_GRAPH = {
"properties": {
"source": {
"type": "string",
- "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created."
+ "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.",
},
"destination": {
"type": "string",
- "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created."
+ "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.",
},
"relationship": {
"type": "string",
- "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected."
+ "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.",
},
"source_type": {
"type": "string",
- "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph."
+ "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.",
},
"destination_type": {
"type": "string",
- "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph."
- }
+ "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.",
+ },
},
- "required": ["source", "destination", "relationship", "source_type", "destination_type"],
- "additionalProperties": False
- }
- }
+ "required": [
+ "source",
+ "destination",
+ "relationship",
+ "source_type",
+ "destination_type",
+ ],
+ "additionalProperties": False,
+ },
+ },
}
@@ -212,9 +225,9 @@ NOOP_STRUCT_TOOL = {
"type": "object",
"properties": {},
"required": [],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
@@ -236,17 +249,23 @@ ADD_MESSAGE_STRUCT_TOOL = {
"source_type": {"type": "string"},
"relation": {"type": "string"},
"destination_node": {"type": "string"},
- "destination_type": {"type": "string"}
+ "destination_type": {"type": "string"},
},
- "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"],
- "additionalProperties": False
- }
+ "required": [
+ "source_node",
+ "source_type",
+ "relation",
+ "destination_node",
+ "destination_type",
+ ],
+ "additionalProperties": False,
+ },
}
},
"required": ["entities"],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
@@ -261,21 +280,17 @@ SEARCH_STRUCT_TOOL = {
"properties": {
"nodes": {
"type": "array",
- "items": {
- "type": "string"
- },
- "description": "List of nodes to search for."
+ "items": {"type": "string"},
+ "description": "List of nodes to search for.",
},
"relations": {
"type": "array",
- "items": {
- "type": "string"
- },
- "description": "List of relations to search for."
- }
+ "items": {"type": "string"},
+ "description": "List of relations to search for.",
+ },
},
"required": ["nodes", "relations"],
- "additionalProperties": False
- }
- }
+ "additionalProperties": False,
+ },
+ },
}
diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py
index e9ed827e..efc14db0 100644
--- a/mem0/graphs/utils.py
+++ b/mem0/graphs/utils.py
@@ -1,4 +1,3 @@
-
UPDATE_GRAPH_PROMPT = """
You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
@@ -55,10 +54,10 @@ Strive for a coherent, easily understandable knowledge graph by maintaining cons
Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction."""
-
def get_update_memory_prompt(existing_memories, memory, template):
return template.format(existing_memories=existing_memories, memory=memory)
+
def get_update_memory_messages(existing_memories, memory):
return [
{
diff --git a/mem0/llms/anthropic.py b/mem0/llms/anthropic.py
index fb390348..5f004ae8 100644
--- a/mem0/llms/anthropic.py
+++ b/mem0/llms/anthropic.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
try:
import anthropic
except ImportError:
- raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
+ raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
@@ -43,8 +43,8 @@ class AnthropicLLM(LLMBase):
system_message = ""
filtered_messages = []
for message in messages:
- if message['role'] == 'system':
- system_message = message['content']
+ if message["role"] == "system":
+ system_message = message["content"]
else:
filtered_messages.append(message)
@@ -56,7 +56,7 @@ class AnthropicLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
- if tools: # TODO: Remove tools if no issues found with new memory addition logic
+ if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py
index 5e7969c1..2bc963c2 100644
--- a/mem0/llms/aws_bedrock.py
+++ b/mem0/llms/aws_bedrock.py
@@ -125,9 +125,7 @@ class AWSBedrockLLM(LLMBase):
},
}
input_body["textGenerationConfig"] = {
- k: v
- for k, v in input_body["textGenerationConfig"].items()
- if v is not None
+ k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
}
return input_body
@@ -161,9 +159,7 @@ class AWSBedrockLLM(LLMBase):
}
}
- for prop, details in (
- function["parameters"].get("properties", {}).items()
- ):
+ for prop, details in function["parameters"].get("properties", {}).items():
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
"type": details.get("type", "string"),
"description": details.get("description", ""),
@@ -216,9 +212,7 @@ class AWSBedrockLLM(LLMBase):
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
- input_body = self._prepare_input(
- provider, self.config.model, prompt, **self.model_kwargs
- )
+ input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
body = json.dumps(input_body)
response = self.client.invoke_model(
diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py
index f093284b..f1fe6863 100644
--- a/mem0/llms/azure_openai.py
+++ b/mem0/llms/azure_openai.py
@@ -15,20 +15,20 @@ class AzureOpenAILLM(LLMBase):
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model = "gpt-4o"
-
+
api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
self.client = AzureOpenAI(
- azure_deployment=azure_deployment,
+ azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
- http_client=self.config.http_client
- )
-
+ http_client=self.config.http_client,
+ )
+
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
@@ -87,7 +87,7 @@ class AzureOpenAILLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
- if tools: # TODO: Remove tools if no issues found with new memory addition logic
+ if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py
index 091f92e3..729523d8 100644
--- a/mem0/llms/azure_openai_structured.py
+++ b/mem0/llms/azure_openai_structured.py
@@ -1,11 +1,11 @@
-import os
import json
+import os
from typing import Dict, List, Optional
from openai import AzureOpenAI
-from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase
class AzureOpenAIStructuredLLM(LLMBase):
@@ -15,21 +15,21 @@ class AzureOpenAIStructuredLLM(LLMBase):
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"
-
+
api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
# Can display a warning if API version is of model and api-version
-
+
self.client = AzureOpenAI(
- azure_deployment=azure_deployment,
+ azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
- http_client=self.config.http_client
- )
-
+ http_client=self.config.http_client,
+ )
+
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py
index fb6dccbf..dcd5b8c7 100644
--- a/mem0/llms/configs.py
+++ b/mem0/llms/configs.py
@@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator
class LlmConfig(BaseModel):
- provider: str = Field(
- description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
- )
- config: Optional[dict] = Field(
- description="Configuration for the specific LLM", default={}
- )
+ provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
+ config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
@field_validator("config")
def validate_config(cls, v, values):
@@ -23,7 +19,7 @@ class LlmConfig(BaseModel):
"litellm",
"azure_openai",
"openai_structured",
- "azure_openai_structured"
+ "azure_openai_structured",
):
return v
else:
diff --git a/mem0/llms/litellm.py b/mem0/llms/litellm.py
index bfe95130..d5896ff8 100644
--- a/mem0/llms/litellm.py
+++ b/mem0/llms/litellm.py
@@ -67,9 +67,7 @@ class LiteLLM(LLMBase):
str: The generated response.
"""
if not litellm.supports_function_calling(self.config.model):
- raise ValueError(
- f"Model '{self.config.model}' in litellm does not support function calling."
- )
+ raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
params = {
"model": self.config.model,
@@ -80,7 +78,7 @@ class LiteLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
- if tools: # TODO: Remove tools if no issues found with new memory addition logic
+ if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py
index eb6abc75..89bef986 100644
--- a/mem0/llms/openai.py
+++ b/mem0/llms/openai.py
@@ -100,7 +100,7 @@ class OpenAILLM(LLMBase):
if response_format:
params["response_format"] = response_format
- if tools: # TODO: Remove tools if no issues found with new memory addition logic
+ if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py
index 0625c1e8..4060afb8 100644
--- a/mem0/llms/openai_structured.py
+++ b/mem0/llms/openai_structured.py
@@ -1,6 +1,5 @@
-import os
import json
-
+import os
from typing import Dict, List, Optional
from openai import OpenAI
@@ -20,7 +19,6 @@ class OpenAIStructuredLLM(LLMBase):
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
self.client = OpenAI(api_key=api_key, base_url=base_url)
-
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
@@ -31,8 +29,8 @@ class OpenAIStructuredLLM(LLMBase):
Returns:
str or dict: The processed response.
- """
-
+ """
+
if tools:
processed_response = {
"content": response.choices[0].message.content,
@@ -52,7 +50,6 @@ class OpenAIStructuredLLM(LLMBase):
else:
return response.choices[0].message.content
-
def generate_response(
self,
@@ -87,4 +84,4 @@ class OpenAIStructuredLLM(LLMBase):
response = self.client.beta.chat.completions.parse(**params)
- return self._parse_response(response, tools)
\ No newline at end of file
+ return self._parse_response(response, tools)
diff --git a/mem0/llms/together.py b/mem0/llms/together.py
index 51ebac66..922a30d2 100644
--- a/mem0/llms/together.py
+++ b/mem0/llms/together.py
@@ -20,7 +20,7 @@ class TogetherLLM(LLMBase):
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
self.client = Together(api_key=api_key)
-
+
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
@@ -79,7 +79,7 @@ class TogetherLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
- if tools: # TODO: Remove tools if no issues found with new memory addition logic
+ if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
diff --git a/mem0/llms/utils/tools.py b/mem0/llms/utils/tools.py
index 64f93145..6857294f 100644
--- a/mem0/llms/utils/tools.py
+++ b/mem0/llms/utils/tools.py
@@ -7,11 +7,9 @@ ADD_MEMORY_TOOL = {
"description": "Add a memory",
"parameters": {
"type": "object",
- "properties": {
- "data": {"type": "string", "description": "Data to add to memory"}
- },
+ "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
- "additionalProperties": False
+ "additionalProperties": False,
},
},
}
@@ -34,7 +32,7 @@ UPDATE_MEMORY_TOOL = {
},
},
"required": ["memory_id", "data"],
- "additionalProperties": False
+ "additionalProperties": False,
},
},
}
@@ -53,7 +51,7 @@ DELETE_MEMORY_TOOL = {
}
},
"required": ["memory_id"],
- "additionalProperties": False
+ "additionalProperties": False,
},
},
}
diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py
index 7cdeb025..13020f07 100644
--- a/mem0/memory/graph_memory.py
+++ b/mem0/memory/graph_memory.py
@@ -3,30 +3,28 @@ import logging
from langchain_community.graphs import Neo4jGraph
from rank_bm25 import BM25Okapi
-from mem0.graphs.tools import (
- ADD_MEMORY_TOOL_GRAPH,
- ADD_MESSAGE_TOOL,
- NOOP_TOOL,
- SEARCH_TOOL,
- UPDATE_MEMORY_TOOL_GRAPH,
- UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
- ADD_MEMORY_STRUCT_TOOL_GRAPH,
- NOOP_STRUCT_TOOL,
- ADD_MESSAGE_STRUCT_TOOL,
- SEARCH_STRUCT_TOOL
-)
-from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
+from mem0.graphs.tools import (ADD_MEMORY_STRUCT_TOOL_GRAPH,
+ ADD_MEMORY_TOOL_GRAPH, ADD_MESSAGE_STRUCT_TOOL,
+ ADD_MESSAGE_TOOL, NOOP_STRUCT_TOOL, NOOP_TOOL,
+ SEARCH_STRUCT_TOOL, SEARCH_TOOL,
+ UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
+ UPDATE_MEMORY_TOOL_GRAPH)
+from mem0.graphs.utils import (EXTRACT_ENTITIES_PROMPT,
+ get_update_memory_messages)
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
+
class MemoryGraph:
def __init__(self, config):
self.config = config
- self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
- self.embedding_model = EmbedderFactory.create(
- self.config.embedder.provider, self.config.embedder.config
+ self.graph = Neo4jGraph(
+ self.config.graph_store.config.url,
+ self.config.graph_store.config.username,
+ self.config.graph_store.config.password,
)
+ self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
self.llm_provider = "openai_structured"
if self.config.llm.provider:
@@ -51,15 +49,23 @@ class MemoryGraph:
search_output = self._search(data, filters)
if self.config.graph_store.custom_prompt:
- messages=[
- {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
+ messages = [
+ {
+ "role": "system",
+ "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace(
+ "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
+ ),
+ },
{"role": "user", "content": data},
]
else:
- messages=[
- {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
+ messages = [
+ {
+ "role": "system",
+ "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id),
+ },
{"role": "user", "content": data},
- ]
+ ]
_tools = [ADD_MESSAGE_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -67,11 +73,11 @@ class MemoryGraph:
extracted_entities = self.llm.generate_response(
messages=messages,
- tools = _tools,
+ tools=_tools,
)
- if extracted_entities['tool_calls']:
- extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
+ if extracted_entities["tool_calls"]:
+ extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
else:
extracted_entities = []
@@ -79,9 +85,13 @@ class MemoryGraph:
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
- _tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
- if self.llm_provider in ["azure_openai_structured","openai_structured"]:
- _tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
+ _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
+ if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+ _tools = [
+ UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
+ ADD_MEMORY_STRUCT_TOOL_GRAPH,
+ NOOP_STRUCT_TOOL,
+ ]
memory_updates = self.llm.generate_response(
messages=update_memory_prompt,
@@ -90,28 +100,29 @@ class MemoryGraph:
to_be_added = []
- for item in memory_updates['tool_calls']:
- if item['name'] == "add_graph_memory":
- to_be_added.append(item['arguments'])
- elif item['name'] == "update_graph_memory":
- self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
- elif item['name'] == "noop":
+ for item in memory_updates["tool_calls"]:
+ if item["name"] == "add_graph_memory":
+ to_be_added.append(item["arguments"])
+ elif item["name"] == "update_graph_memory":
+ self._update_relationship(
+ item["arguments"]["source"],
+ item["arguments"]["destination"],
+ item["arguments"]["relationship"],
+ filters,
+ )
+ elif item["name"] == "noop":
continue
returned_entities = []
for item in to_be_added:
- source = item['source'].lower().replace(" ", "_")
- source_type = item['source_type'].lower().replace(" ", "_")
- relation = item['relationship'].lower().replace(" ", "_")
- destination = item['destination'].lower().replace(" ", "_")
- destination_type = item['destination_type'].lower().replace(" ", "_")
+ source = item["source"].lower().replace(" ", "_")
+ source_type = item["source_type"].lower().replace(" ", "_")
+ relation = item["relationship"].lower().replace(" ", "_")
+ destination = item["destination"].lower().replace(" ", "_")
+ destination_type = item["destination_type"].lower().replace(" ", "_")
- returned_entities.append({
- "source" : source,
- "relationship" : relation,
- "target" : destination
- })
+ returned_entities.append({"source": source, "relationship": relation, "target": destination})
# Create embeddings
source_embedding = self.embedding_model.embed(source)
@@ -135,7 +146,7 @@ class MemoryGraph:
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
- "user_id": filters["user_id"]
+ "user_id": filters["user_id"],
}
_ = self.graph.query(cypher, params=params)
@@ -150,19 +161,22 @@ class MemoryGraph:
_tools = [SEARCH_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
- {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
+ {
+ "role": "system",
+ "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.",
+ },
{"role": "user", "content": query},
],
- tools = _tools
+ tools=_tools,
)
node_list = []
relation_list = []
- for item in search_results['tool_calls']:
- if item['name'] == "search":
+ for item in search_results["tool_calls"]:
+ if item["name"] == "search":
try:
- node_list.extend(item['arguments']['nodes'])
+ node_list.extend(item["arguments"]["nodes"])
except Exception as e:
logger.error(f"Error in search tool: {e}")
@@ -201,13 +215,16 @@ class MemoryGraph:
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC
"""
- params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
+ params = {
+ "n_embedding": n_embedding,
+ "threshold": self.threshold,
+ "user_id": filters["user_id"],
+ }
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
-
def search(self, query, filters):
"""
Search for memories and related graph data.
@@ -235,17 +252,12 @@ class MemoryGraph:
search_results = []
for item in reranked_results:
- search_results.append({
- "source": item[0],
- "relationship": item[1],
- "target": item[2]
- })
+ search_results.append({"source": item[0], "relationship": item[1], "target": item[2]})
logger.info(f"Returned {len(search_results)} search results")
return search_results
-
def delete_all(self, filters):
cypher = """
MATCH (n {user_id: $user_id})
@@ -254,7 +266,6 @@ class MemoryGraph:
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
-
def get_all(self, filters):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -276,17 +287,18 @@ class MemoryGraph:
final_results = []
for result in results:
- final_results.append({
- "source": result['source'],
- "relationship": result['relationship'],
- "target": result['target']
- })
+ final_results.append(
+ {
+ "source": result["source"],
+ "relationship": result["relationship"],
+ "target": result["target"],
+ }
+ )
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
-
-
+
def _update_relationship(self, source, target, relationship, filters):
"""
Update or create a relationship between two nodes in the graph.
@@ -309,14 +321,20 @@ class MemoryGraph:
MERGE (n1 {name: $source, user_id: $user_id})
MERGE (n2 {name: $target, user_id: $user_id})
"""
- self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
+ self.graph.query(
+ check_and_create_query,
+ params={"source": source, "target": target, "user_id": filters["user_id"]},
+ )
# Delete any existing relationship between the nodes
delete_query = """
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
DELETE r
"""
- self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
+ self.graph.query(
+ delete_query,
+ params={"source": source, "target": target, "user_id": filters["user_id"]},
+ )
# Create the new relationship
create_query = f"""
@@ -324,7 +342,10 @@ class MemoryGraph:
CREATE (n1)-[r:{relationship}]->(n2)
RETURN n1, r, n2
"""
- result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
+ result = self.graph.query(
+ create_query,
+ params={"source": source, "target": target, "user_id": filters["user_id"]},
+ )
if not result:
raise Exception(f"Failed to update or create relationship between {source} and {target}")
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index 5516227c..8a0cc1ac 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -10,14 +10,14 @@ from typing import Any, Dict
import pytz
from pydantic import ValidationError
+from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
-from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
-from mem0.configs.base import MemoryItem, MemoryConfig
+from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
# Setup user config
setup_config()
@@ -30,9 +30,7 @@ class Memory(MemoryBase):
self.config = config
self.custom_prompt = self.config.custom_prompt
- self.embedding_model = EmbedderFactory.create(
- self.config.embedder.provider, self.config.embedder.config
- )
+ self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
@@ -45,12 +43,12 @@ class Memory(MemoryBase):
if self.version == "v1.1" and self.config.graph_store.config:
from mem0.memory.graph_memory import MemoryGraph
+
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("mem0.init", self)
-
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
try:
@@ -60,7 +58,6 @@ class Memory(MemoryBase):
raise
return cls(config)
-
def add(
self,
messages,
@@ -98,9 +95,7 @@ class Memory(MemoryBase):
filters["run_id"] = metadata["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
- raise ValueError(
- "One of the filters: user_id, agent_id or run_id is required!"
- )
+ raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -116,8 +111,8 @@ class Memory(MemoryBase):
if self.version == "v1.1":
return {
- "results" : vector_store_result,
- "relations" : graph_result,
+ "results": vector_store_result,
+ "relations": graph_result,
}
else:
warnings.warn(
@@ -125,29 +120,29 @@ class Memory(MemoryBase):
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
return {"message": "ok"}
-
def _add_to_vector_store(self, messages, metadata, filters):
parsed_messages = parse_messages(messages)
if self.custom_prompt:
- system_prompt=self.custom_prompt
- user_prompt=f"Input: {parsed_messages}"
+ system_prompt = self.custom_prompt
+ user_prompt = f"Input: {parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
response = self.llm.generate_response(
- messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ],
response_format={"type": "json_object"},
)
try:
- new_retrieved_facts = json.loads(response)[
- "facts"
- ]
+ new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
@@ -178,24 +173,30 @@ class Memory(MemoryBase):
logging.info(resp)
try:
if resp["event"] == "ADD":
- memory_id = self._create_memory(data=resp["text"], metadata=metadata)
- returned_memories.append({
- "memory" : resp["text"],
- "event" : resp["event"],
- })
+ _ = self._create_memory(data=resp["text"], metadata=metadata)
+ returned_memories.append(
+ {
+ "memory": resp["text"],
+ "event": resp["event"],
+ }
+ )
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
- returned_memories.append({
- "memory" : resp["text"],
- "event" : resp["event"],
- "previous_memory" : resp["old_memory"],
- })
+ returned_memories.append(
+ {
+ "memory": resp["text"],
+ "event": resp["event"],
+ "previous_memory": resp["old_memory"],
+ }
+ )
elif resp["event"] == "DELETE":
self._delete_memory(memory_id=resp["id"])
- returned_memories.append({
- "memory" : resp["text"],
- "event" : resp["event"],
- })
+ returned_memories.append(
+ {
+ "memory": resp["text"],
+ "event": resp["event"],
+ }
+ )
elif resp["event"] == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -206,7 +207,6 @@ class Memory(MemoryBase):
capture_event("mem0.add", self)
return returned_memories
-
def _add_to_graph(self, messages, filters):
added_entities = []
@@ -220,7 +220,6 @@ class Memory(MemoryBase):
return added_entities
-
def get(self, memory_id):
"""
Retrieve a memory by ID.
@@ -236,11 +235,7 @@ class Memory(MemoryBase):
if not memory:
return None
- filters = {
- key: memory.payload[key]
- for key in ["user_id", "agent_id", "run_id"]
- if memory.payload.get(key)
- }
+ filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
@@ -261,9 +256,7 @@ class Memory(MemoryBase):
"created_at",
"updated_at",
}
- additional_metadata = {
- k: v for k, v in memory.payload.items() if k not in excluded_keys
- }
+ additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
@@ -271,7 +264,6 @@ class Memory(MemoryBase):
return result
-
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories.
@@ -288,10 +280,12 @@ class Memory(MemoryBase):
filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
-
+
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
- future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+ future_graph_entities = (
+ executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+ )
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -307,15 +301,22 @@ class Memory(MemoryBase):
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
return all_memories
-
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
- excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
+ excluded_keys = {
+ "user_id",
+ "agent_id",
+ "run_id",
+ "hash",
+ "data",
+ "created_at",
+ "updated_at",
+ }
all_memories = [
{
**MemoryItem(
@@ -325,19 +326,9 @@ class Memory(MemoryBase):
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
- **{
- key: mem.payload[key]
- for key in ["user_id", "agent_id", "run_id"]
- if key in mem.payload
- },
+ **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
- {
- "metadata": {
- k: v
- for k, v in mem.payload.items()
- if k not in excluded_keys
- }
- }
+ {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
@@ -346,10 +337,7 @@ class Memory(MemoryBase):
]
return all_memories
-
- def search(
- self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
- ):
+ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
"""
Search for memories.
@@ -373,15 +361,21 @@ class Memory(MemoryBase):
filters["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
- raise ValueError(
- "One of the filters: user_id, agent_id or run_id is required!"
- )
+ raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
- capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
+ capture_event(
+ "mem0.search",
+ self,
+ {"filters": len(filters), "limit": limit, "version": self.version},
+ )
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
- future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
+ future_graph_entities = (
+ executor.submit(self.graph.search, query, filters)
+ if self.version == "v1.1" and self.enable_graph
+ else None
+ )
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -390,23 +384,20 @@ class Memory(MemoryBase):
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
else:
- return {"results" : original_memories}
+ return {"results": original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
- stacklevel=2
+ stacklevel=2,
)
return original_memories
-
def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query)
- memories = self.vector_store.search(
- query=embeddings, limit=limit, filters=filters
- )
+ memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
excluded_keys = {
"user_id",
@@ -428,19 +419,9 @@ class Memory(MemoryBase):
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
- **{
- key: mem.payload[key]
- for key in ["user_id", "agent_id", "run_id"]
- if key in mem.payload
- },
+ **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
- {
- "metadata": {
- k: v
- for k, v in mem.payload.items()
- if k not in excluded_keys
- }
- }
+ {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
@@ -450,7 +431,6 @@ class Memory(MemoryBase):
return original_memories
-
def update(self, memory_id, data):
"""
Update a memory by ID.
@@ -466,7 +446,6 @@ class Memory(MemoryBase):
self._update_memory(memory_id, data)
return {"message": "Memory updated successfully!"}
-
def delete(self, memory_id):
"""
Delete a memory by ID.
@@ -478,7 +457,6 @@ class Memory(MemoryBase):
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
-
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories.
@@ -511,8 +489,7 @@ class Memory(MemoryBase):
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
- return {'message': 'Memories deleted successfully!'}
-
+ return {"message": "Memories deleted successfully!"}
def history(self, memory_id):
"""
@@ -527,7 +504,6 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
-
def _create_memory(self, data, metadata=None):
logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data)
@@ -542,12 +518,9 @@ class Memory(MemoryBase):
ids=[memory_id],
payloads=[metadata],
)
- self.db.add_history(
- memory_id, None, data, "ADD", created_at=metadata["created_at"]
- )
+ self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
return memory_id
-
def _update_memory(self, memory_id, data, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -557,9 +530,7 @@ class Memory(MemoryBase):
new_metadata["data"] = data
new_metadata["hash"] = existing_memory.payload.get("hash")
new_metadata["created_at"] = existing_memory.payload.get("created_at")
- new_metadata["updated_at"] = datetime.now(
- pytz.timezone("US/Pacific")
- ).isoformat()
+ new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
if "user_id" in existing_memory.payload:
new_metadata["user_id"] = existing_memory.payload["user_id"]
@@ -584,7 +555,6 @@ class Memory(MemoryBase):
updated_at=new_metadata["updated_at"],
)
-
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -592,7 +562,6 @@ class Memory(MemoryBase):
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
-
def reset(self):
"""
Reset the memory store.
@@ -602,6 +571,5 @@ class Memory(MemoryBase):
self.db.reset()
capture_event("mem0.reset", self)
-
def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")
diff --git a/mem0/memory/storage.py b/mem0/memory/storage.py
index 126df85d..87a256dc 100644
--- a/mem0/memory/storage.py
+++ b/mem0/memory/storage.py
@@ -12,9 +12,7 @@ class SQLiteManager:
with self.connection:
cursor = self.connection.cursor()
- cursor.execute(
- "SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
- )
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
table_exists = cursor.fetchone() is not None
if table_exists:
@@ -62,7 +60,7 @@ class SQLiteManager:
INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
FROM old_history
- """
+ """ # noqa: E501
)
cursor.execute("DROP TABLE old_history")
diff --git a/mem0/memory/telemetry.py b/mem0/memory/telemetry.py
index 9b78d775..5cab2b1f 100644
--- a/mem0/memory/telemetry.py
+++ b/mem0/memory/telemetry.py
@@ -1,7 +1,7 @@
import logging
+import os
import platform
import sys
-import os
from posthog import Posthog
@@ -15,8 +15,9 @@ if isinstance(MEM0_TELEMETRY, str):
if not isinstance(MEM0_TELEMETRY, bool):
raise ValueError("MEM0_TELEMETRY must be a boolean value.")
-logging.getLogger('posthog').setLevel(logging.CRITICAL + 1)
-logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1)
+logging.getLogger("posthog").setLevel(logging.CRITICAL + 1)
+logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)
+
class AnonymousTelemetry:
def __init__(self, project_api_key, host):
@@ -24,9 +25,8 @@ class AnonymousTelemetry:
# Call setup config to ensure that the user_id is generated
setup_config()
self.user_id = get_user_id()
- # Optional
- if not MEM0_TELEMETRY:
- self.posthog.disabled = True
+ if not MEM0_TELEMETRY:
+ self.posthog.disabled = True
def capture_event(self, event_name, properties=None):
if properties is None:
@@ -40,9 +40,7 @@ class AnonymousTelemetry:
"machine": platform.machine(),
**properties,
}
- self.posthog.capture(
- distinct_id=self.user_id, event=event_name, properties=properties
- )
+ self.posthog.capture(distinct_id=self.user_id, event=event_name, properties=properties)
def identify_user(self, user_id, properties=None):
if properties is None:
@@ -65,6 +63,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
"collection": memory_instance.collection_name,
"vector_size": memory_instance.embedding_model.config.embedding_dims,
"history_store": "sqlite",
+ "graph_store": f"{memory_instance.graph.__class__.__module__}.{memory_instance.graph.__class__.__name__}" if memory_instance.config.graph_store.config else None,
"vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
"llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
"embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
@@ -76,7 +75,6 @@ def capture_event(event_name, memory_instance, additional_data=None):
telemetry.capture_event(event_name, event_data)
-
def capture_client_event(event_name, instance, additional_data=None):
event_data = {
"function": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
diff --git a/mem0/memory/utils.py b/mem0/memory/utils.py
index a0c82fed..a7e7bc35 100644
--- a/mem0/memory/utils.py
+++ b/mem0/memory/utils.py
@@ -4,13 +4,14 @@ from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
def get_fact_retrieval_messages(message):
return FACT_RETRIEVAL_PROMPT, f"Input: {message}"
+
def parse_messages(messages):
- response = ""
- for msg in messages:
- if msg["role"] == "system":
- response += f"system: {msg['content']}\n"
- if msg["role"] == "user":
- response += f"user: {msg['content']}\n"
- if msg["role"] == "assistant":
- response += f"assistant: {msg['content']}\n"
- return response
+ response = ""
+ for msg in messages:
+ if msg["role"] == "system":
+ response += f"system: {msg['content']}\n"
+ if msg["role"] == "user":
+ response += f"user: {msg['content']}\n"
+ if msg["role"] == "assistant":
+ response += f"assistant: {msg['content']}\n"
+ return response
diff --git a/mem0/proxy/main.py b/mem0/proxy/main.py
index bb614f4f..b13c681e 100644
--- a/mem0/proxy/main.py
+++ b/mem0/proxy/main.py
@@ -10,7 +10,7 @@ try:
import litellm
except ImportError:
user_input = input("The 'litellm' library is required. Install it now? [y/N]: ")
- if user_input.lower() == 'y':
+ if user_input.lower() == "y":
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
import litellm
@@ -105,16 +105,10 @@ class Completions:
prepared_messages = self._prepare_messages(messages)
if prepared_messages[-1]["role"] == "user":
- self._async_add_to_memory(
- messages, user_id, agent_id, run_id, metadata, filters
- )
- relevant_memories = self._fetch_relevant_memories(
- messages, user_id, agent_id, run_id, filters, limit
- )
+ self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
+ relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
- prepared_messages[-1]["content"] = self._format_query_with_memories(
- messages, relevant_memories
- )
+ prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
response = litellm.completion(
model=model,
@@ -156,9 +150,7 @@ class Completions:
messages[0]["content"] = MEMORY_ANSWER_PROMPT
return messages
- def _async_add_to_memory(
- self, messages, user_id, agent_id, run_id, metadata, filters
- ):
+ def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
def add_task():
logger.debug("Adding to memory asynchronously")
self.mem0_client.add(
@@ -172,13 +164,9 @@ class Completions:
threading.Thread(target=add_task, daemon=True).start()
- def _fetch_relevant_memories(
- self, messages, user_id, agent_id, run_id, filters, limit
- ):
+ def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
# Currently, only pass the last 6 messages to the search API to prevent long query
- message_input = [
- f"{message['role']}: {message['content']}" for message in messages
- ][-6:]
+ message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
# TODO: Make it better by summarizing the past conversation
return self.mem0_client.search(
query="\n".join(message_input),
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 7047febb..21c04459 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -21,7 +21,7 @@ class LlmFactory:
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
"anthropic": "mem0.llms.anthropic.AnthropicLLM",
- "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM"
+ "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM",
}
@classmethod
@@ -59,7 +59,7 @@ class VectorStoreFactory:
"qdrant": "mem0.vector_stores.qdrant.Qdrant",
"chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector",
- "milvus": "mem0.vector_stores.milvus.MilvusDB"
+ "milvus": "mem0.vector_stores.milvus.MilvusDB",
}
@classmethod
diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py
index 0dc97a3f..efb9fddb 100644
--- a/mem0/vector_stores/chroma.py
+++ b/mem0/vector_stores/chroma.py
@@ -80,24 +80,14 @@ class ChromaDB(VectorStoreBase):
values.append(value)
ids, distances, metadatas = values
- max_length = max(
- len(v) for v in values if isinstance(v, list) and v is not None
- )
+ max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
- score=(
- distances[i]
- if isinstance(distances, list) and distances and i < len(distances)
- else None
- ),
- payload=(
- metadatas[i]
- if isinstance(metadatas, list) and metadatas and i < len(metadatas)
- else None
- ),
+ score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
+ payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
)
result.append(entry)
@@ -143,9 +133,7 @@ class ChromaDB(VectorStoreBase):
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
- def search(
- self, query: List[list], limit: int = 5, filters: Optional[Dict] = None
- ) -> List[OutputData]:
+ def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
"""
Search for similar vectors.
@@ -157,9 +145,7 @@ class ChromaDB(VectorStoreBase):
Returns:
List[OutputData]: Search results.
"""
- results = self.collection.query(
- query_embeddings=query, where=filters, n_results=limit
- )
+ results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
final_results = self._parse_output(results)
return final_results
@@ -225,9 +211,7 @@ class ChromaDB(VectorStoreBase):
"""
return self.client.get_collection(name=self.collection_name)
- def list(
- self, filters: Optional[Dict] = None, limit: int = 100
- ) -> List[OutputData]:
+ def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List all vectors in a collection.
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index d4cd6b13..65e55a53 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -8,15 +8,13 @@ class VectorStoreConfig(BaseModel):
description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
default="qdrant",
)
- config: Optional[Dict] = Field(
- description="Configuration for the specific vector store", default=None
- )
+ config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
_provider_configs: Dict[str, str] = {
"qdrant": "QdrantConfig",
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
- "milvus" : "MilvusDBConfig"
+ "milvus": "MilvusDBConfig",
}
@model_validator(mode="after")
diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py
index eeed7ac6..e1df3458 100644
--- a/mem0/vector_stores/milvus.py
+++ b/mem0/vector_stores/milvus.py
@@ -1,15 +1,17 @@
import logging
+from typing import Dict, Optional
+
from pydantic import BaseModel
-from typing import Optional, Dict
-from mem0.vector_stores.base import VectorStoreBase
+
from mem0.configs.vector_stores.milvus import MetricType
+from mem0.vector_stores.base import VectorStoreBase
try:
- import pymilvus
+ import pymilvus # noqa: F401
except ImportError:
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
-from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
+from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
logger = logging.getLogger(__name__)
@@ -20,9 +22,15 @@ class OutputData(BaseModel):
payload: Optional[Dict] # metadata
-
class MilvusDB(VectorStoreBase):
- def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None:
+ def __init__(
+ self,
+ url: str,
+ token: str,
+ collection_name: str,
+ embedding_model_dims: int,
+ metric_type: MetricType,
+ ) -> None:
"""Initialize the MilvusDB database.
Args:
@@ -32,22 +40,21 @@ class MilvusDB(VectorStoreBase):
embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
metric_type (MetricType): Metric type for similarity search (defaults to L2).
"""
-
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.metric_type = metric_type
-
- self.client = MilvusClient(uri=url,token=token)
-
+ self.client = MilvusClient(uri=url, token=token)
self.create_col(
collection_name=self.collection_name,
vector_size=self.embedding_model_dims,
- metric_type=self.metric_type
+ metric_type=self.metric_type,
)
-
-
+
def create_col(
- self, collection_name : str, vector_size : str, metric_type : MetricType = MetricType.COSINE
+ self,
+ collection_name: str,
+ vector_size: str,
+ metric_type: MetricType = MetricType.COSINE,
) -> None:
"""Create a new collection with index_type AUTOINDEX.
@@ -65,7 +72,7 @@ class MilvusDB(VectorStoreBase):
FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="metadata", dtype=DataType.JSON),
]
-
+
schema = CollectionSchema(fields, enable_dynamic_field=True)
index = self.client.prepare_index_params(
@@ -73,12 +80,10 @@ class MilvusDB(VectorStoreBase):
metric_type=metric_type,
index_type="AUTOINDEX",
index_name="vector_index",
- params={ "nlist": 128 }
+ params={"nlist": 128},
)
-
self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
-
-
+
def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
"""Insert vectors into a collection.
@@ -91,9 +96,8 @@ class MilvusDB(VectorStoreBase):
data = {"id": idx, "vectors": embedding, "metadata": metadata}
self.client.insert(collection_name=self.collection_name, data=data, **kwargs)
-
def _create_filter(self, filters: dict):
- """Prepare filters for efficient query.
+ """Prepare filters for efficient query.
Args:
filters (dict): filters [user_id, agent_id, run_id]
@@ -109,8 +113,7 @@ class MilvusDB(VectorStoreBase):
operands.append(f'(metadata["{key}"] == {value})')
return " and ".join(operands)
-
-
+
def _parse_output(self, data: list):
"""
Parse the output data.
@@ -125,16 +128,15 @@ class MilvusDB(VectorStoreBase):
for value in data:
uid, score, metadata = (
- value.get("id"),
- value.get("distance"),
- value.get("entity",{}).get("metadata")
+ value.get("id"),
+ value.get("distance"),
+ value.get("entity", {}).get("metadata"),
)
-
+
memory_obj = OutputData(id=uid, score=score, payload=metadata)
memory.append(memory_obj)
return memory
-
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
"""
@@ -150,14 +152,15 @@ class MilvusDB(VectorStoreBase):
"""
query_filter = self._create_filter(filters) if filters else None
hits = self.client.search(
- collection_name=self.collection_name,
- data=[query], limit=limit, filter=query_filter,
- output_fields=["*"]
+ collection_name=self.collection_name,
+ data=[query],
+ limit=limit,
+ filter=query_filter,
+ output_fields=["*"],
)
result = self._parse_output(data=hits[0])
-
return result
-
+
def delete(self, vector_id):
"""
Delete a vector by ID.
@@ -166,7 +169,6 @@ class MilvusDB(VectorStoreBase):
vector_id (str): ID of the vector to delete.
"""
self.client.delete(collection_name=self.collection_name, ids=vector_id)
-
def update(self, vector_id=None, vector=None, payload=None):
"""
@@ -177,7 +179,7 @@ class MilvusDB(VectorStoreBase):
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
- schema = {"id" : vector_id, "vectors": vector, "metadata" : payload}
+ schema = {"id": vector_id, "vectors": vector, "metadata": payload}
self.client.upsert(collection_name=self.collection_name, data=schema)
def get(self, vector_id):
@@ -191,7 +193,11 @@ class MilvusDB(VectorStoreBase):
OutputData: Retrieved vector.
"""
result = self.client.get(collection_name=self.collection_name, ids=vector_id)
- output = OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None))
+ output = OutputData(
+ id=result[0].get("id", None),
+ score=None,
+ payload=result[0].get("metadata", None),
+ )
return output
def list_cols(self):
@@ -228,12 +234,9 @@ class MilvusDB(VectorStoreBase):
List[OutputData]: List of vectors.
"""
query_filter = self._create_filter(filters) if filters else None
- result = self.client.query(
- collection_name=self.collection_name,
- filter=query_filter,
- limit=limit)
+ result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
memories = []
for data in result:
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj)
- return [memories]
\ No newline at end of file
+ return [memories]
diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py
index f9ec3f97..c8893e37 100644
--- a/mem0/vector_stores/pgvector.py
+++ b/mem0/vector_stores/pgvector.py
@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
+
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
@@ -22,7 +23,15 @@ class OutputData(BaseModel):
class PGVector(VectorStoreBase):
def __init__(
- self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann
+ self,
+ dbname,
+ collection_name,
+ embedding_model_dims,
+ user,
+ password,
+ host,
+ port,
+ diskann,
):
"""
Initialize the PGVector database.
@@ -40,9 +49,7 @@ class PGVector(VectorStoreBase):
self.collection_name = collection_name
self.use_diskann = diskann
- self.conn = psycopg2.connect(
- dbname=dbname, user=user, password=password, host=host, port=port
- )
+ self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
self.cur = self.conn.cursor()
collections = self.list_cols()
@@ -73,7 +80,8 @@ class PGVector(VectorStoreBase):
self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
if self.cur.fetchone():
# Create DiskANN index if extension is installed for faster search
- self.cur.execute(f"""
+ self.cur.execute(
+ f"""
CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx
ON {self.collection_name}
USING diskann (vector);
@@ -94,10 +102,7 @@ class PGVector(VectorStoreBase):
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
json_payloads = [json.dumps(payload) for payload in payloads]
- data = [
- (id, vector, payload)
- for id, vector, payload in zip(ids, vectors, json_payloads)
- ]
+ data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
execute_values(
self.cur,
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
@@ -125,9 +130,7 @@ class PGVector(VectorStoreBase):
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
- filter_clause = (
- "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
- )
+ filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
self.cur.execute(
f"""
@@ -137,13 +140,11 @@ class PGVector(VectorStoreBase):
ORDER BY distance
LIMIT %s
""",
- (query, *filter_params, limit),
+ (query, *filter_params, limit),
)
results = self.cur.fetchall()
- return [
- OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results
- ]
+ return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
def delete(self, vector_id):
"""
@@ -152,9 +153,7 @@ class PGVector(VectorStoreBase):
Args:
vector_id (str): ID of the vector to delete.
"""
- self.cur.execute(
- f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)
- )
+ self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
self.conn.commit()
def update(self, vector_id, vector=None, payload=None):
@@ -204,9 +203,7 @@ class PGVector(VectorStoreBase):
Returns:
List[str]: List of collection names.
"""
- self.cur.execute(
- "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
- )
+ self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
return [row[0] for row in self.cur.fetchall()]
def delete_col(self):
@@ -254,9 +251,7 @@ class PGVector(VectorStoreBase):
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
- filter_clause = (
- "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
- )
+ filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
query = f"""
SELECT id, vector, payload
diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py
index 3ecb93f9..0afcb6c4 100644
--- a/mem0/vector_stores/qdrant.py
+++ b/mem0/vector_stores/qdrant.py
@@ -3,16 +3,9 @@ import os
import shutil
from qdrant_client import QdrantClient
-from qdrant_client.models import (
- Distance,
- FieldCondition,
- Filter,
- MatchValue,
- PointIdsList,
- PointStruct,
- Range,
- VectorParams,
-)
+from qdrant_client.models import (Distance, FieldCondition, Filter, MatchValue,
+ PointIdsList, PointStruct, Range,
+ VectorParams)
from mem0.vector_stores.base import VectorStoreBase
@@ -68,9 +61,7 @@ class Qdrant(VectorStoreBase):
self.collection_name = collection_name
self.create_col(embedding_model_dims, on_disk)
- def create_col(
- self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE
- ):
+ def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
"""
Create a new collection.
@@ -83,16 +74,12 @@ class Qdrant(VectorStoreBase):
response = self.list_cols()
for collection in response.collections:
if collection.name == self.collection_name:
- logging.debug(
- f"Collection {self.collection_name} already exists. Skipping creation."
- )
+ logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
return
self.client.create_collection(
collection_name=self.collection_name,
- vectors_config=VectorParams(
- size=vector_size, distance=distance, on_disk=on_disk
- ),
+ vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
)
def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -128,15 +115,9 @@ class Qdrant(VectorStoreBase):
conditions = []
for key, value in filters.items():
if isinstance(value, dict) and "gte" in value and "lte" in value:
- conditions.append(
- FieldCondition(
- key=key, range=Range(gte=value["gte"], lte=value["lte"])
- )
- )
+ conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
else:
- conditions.append(
- FieldCondition(key=key, match=MatchValue(value=value))
- )
+ conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
return Filter(must=conditions) if conditions else None
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
@@ -196,9 +177,7 @@ class Qdrant(VectorStoreBase):
Returns:
dict: Retrieved vector.
"""
- result = self.client.retrieve(
- collection_name=self.collection_name, ids=[vector_id], with_payload=True
- )
+ result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
return result[0] if result else None
def list_cols(self) -> list:
diff --git a/poetry.lock b/poetry.lock
index 5e9aedb4..0360b809 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -1623,28 +1623,29 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "ruff"
-version = "0.4.10"
+version = "0.6.5"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"},
- {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"},
- {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"},
- {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"},
- {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"},
- {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"},
- {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"},
- {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"},
- {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"},
- {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"},
- {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"},
+ {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"},
+ {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"},
+ {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"},
+ {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"},
+ {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"},
+ {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"},
+ {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"},
+ {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"},
+ {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"},
+ {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"},
+ {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"},
+ {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"},
]
[[package]]
@@ -1743,7 +1744,7 @@ files = [
]
[package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
+greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"}
typing-extensions = ">=4.6.0"
[package.extras]
@@ -1966,4 +1967,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
-content-hash = "5a74dacc8f9b1b40bb9d53fbbdcb0a95f5d05d55ffd9d61af870ca8a731954b4"
+content-hash = "56197730e020f77ee9824292f34348bbe935b42519b4027f6fb131084b88300b"
diff --git a/pyproject.toml b/pyproject.toml
index 41e17e04..3f909228 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ rank-bm25 = "^0.2.2"
pytest = "^8.2.2"
[tool.poetry.group.dev.dependencies]
-ruff = "^0.4.8"
+ruff = "^0.6.5"
isort = "^5.13.2"
pytest = "^8.2.2"
@@ -38,3 +38,7 @@ pytest = "^8.2.2"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+line-length = 120
+exclude = ["embedchain/"]
diff --git a/tests/embeddings/test_huggingface_embeddings.py b/tests/embeddings/test_huggingface_embeddings.py
index 13a36b0c..de6f5852 100644
--- a/tests/embeddings/test_huggingface_embeddings.py
+++ b/tests/embeddings/test_huggingface_embeddings.py
@@ -37,9 +37,7 @@ def test_embed_custom_model(mock_sentence_transformer):
def test_embed_with_model_kwargs(mock_sentence_transformer):
- config = BaseEmbedderConfig(
- model="all-MiniLM-L6-v2", model_kwargs={"device": "cuda"}
- )
+ config = BaseEmbedderConfig(model="all-MiniLM-L6-v2", model_kwargs={"device": "cuda"})
embedder = HuggingFaceEmbedding(config)
mock_sentence_transformer.encode.return_value = [0.7, 0.8, 0.9]
diff --git a/tests/embeddings/test_ollama_embeddings.py b/tests/embeddings/test_ollama_embeddings.py
index 821eaecf..0aa428b7 100644
--- a/tests/embeddings/test_ollama_embeddings.py
+++ b/tests/embeddings/test_ollama_embeddings.py
@@ -23,9 +23,7 @@ def test_embed_text(mock_ollama_client):
text = "Sample text to embed."
embedding = embedder.embed(text)
- mock_ollama_client.embeddings.assert_called_once_with(
- model="nomic-embed-text", prompt=text
- )
+ mock_ollama_client.embeddings.assert_called_once_with(model="nomic-embed-text", prompt=text)
assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
diff --git a/tests/embeddings/test_openai_embeddings.py b/tests/embeddings/test_openai_embeddings.py
index 875d5149..113a8e64 100644
--- a/tests/embeddings/test_openai_embeddings.py
+++ b/tests/embeddings/test_openai_embeddings.py
@@ -21,9 +21,7 @@ def test_embed_default_model(mock_openai_client):
result = embedder.embed("Hello world")
- mock_openai_client.embeddings.create.assert_called_once_with(
- input=["Hello world"], model="text-embedding-3-small"
- )
+ mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small")
assert result == [0.1, 0.2, 0.3]
@@ -51,9 +49,7 @@ def test_embed_removes_newlines(mock_openai_client):
result = embedder.embed("Hello\nworld")
- mock_openai_client.embeddings.create.assert_called_once_with(
- input=["Hello world"], model="text-embedding-3-small"
- )
+ mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small")
assert result == [0.7, 0.8, 0.9]
diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py
index 63eb91b0..e54d244f 100644
--- a/tests/llms/test_azure_openai.py
+++ b/tests/llms/test_azure_openai.py
@@ -1,4 +1,3 @@
-
from unittest.mock import Mock, patch
import httpx
@@ -7,26 +6,28 @@ import pytest
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.azure_openai import AzureOpenAILLM
-MODEL = "gpt-4o" # or your custom deployment name
+MODEL = "gpt-4o" # or your custom deployment name
TEMPERATURE = 0.7
MAX_TOKENS = 100
TOP_P = 1.0
+
@pytest.fixture
def mock_openai_client():
- with patch('mem0.llms.azure_openai.AzureOpenAI') as mock_openai:
+ with patch("mem0.llms.azure_openai.AzureOpenAI") as mock_openai:
mock_client = Mock()
mock_openai.return_value = mock_client
yield mock_client
+
def test_generate_response_without_tools(mock_openai_client):
config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
llm = AzureOpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
-
+
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_openai_client.chat.completions.create.return_value = mock_response
@@ -34,11 +35,7 @@ def test_generate_response_without_tools(mock_openai_client):
response = llm.generate_response(messages)
mock_openai_client.chat.completions.create.assert_called_once_with(
- model=MODEL,
- messages=messages,
- temperature=TEMPERATURE,
- max_tokens=MAX_TOKENS,
- top_p=TOP_P
+ model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
)
assert response == "I'm doing well, thank you for asking!"
@@ -48,7 +45,7 @@ def test_generate_response_with_tools(mock_openai_client):
llm = AzureOpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
@@ -58,23 +55,21 @@ def test_generate_response_with_tools(mock_openai_client):
"description": "Add a memory",
"parameters": {
"type": "object",
- "properties": {
- "data": {"type": "string", "description": "Data to add to memory"}
- },
+ "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
-
+
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
-
+
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
-
+
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response
@@ -88,24 +83,33 @@ def test_generate_response_with_tools(mock_openai_client):
max_tokens=MAX_TOKENS,
top_p=TOP_P,
tools=tools,
- tool_choice="auto"
+ tool_choice="auto",
)
-
+
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
- assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
+ assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
+
def test_generate_with_http_proxies():
mock_http_client = Mock(spec=httpx.Client)
mock_http_client_instance = Mock(spec=httpx.Client)
mock_http_client.return_value = mock_http_client_instance
- with (patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai,
- patch("httpx.Client", new=mock_http_client) as mock_http_client):
- config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P,
- api_key="test", http_client_proxies="http://testproxy.mem0.net:8000",
- azure_kwargs= {"api_key" : "test"})
+ with (
+ patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai,
+ patch("httpx.Client", new=mock_http_client) as mock_http_client,
+ ):
+ config = BaseLlmConfig(
+ model=MODEL,
+ temperature=TEMPERATURE,
+ max_tokens=MAX_TOKENS,
+ top_p=TOP_P,
+ api_key="test",
+ http_client_proxies="http://testproxy.mem0.net:8000",
+ azure_kwargs={"api_key": "test"},
+ )
_ = AzureOpenAILLM(config)
@@ -114,6 +118,6 @@ def test_generate_with_http_proxies():
http_client=mock_http_client_instance,
azure_deployment=None,
azure_endpoint=None,
- api_version=None
+ api_version=None,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
diff --git a/tests/llms/test_groq.py b/tests/llms/test_groq.py
index e7d1f51c..288b37f8 100644
--- a/tests/llms/test_groq.py
+++ b/tests/llms/test_groq.py
@@ -8,7 +8,7 @@ from mem0.llms.groq import GroqLLM
@pytest.fixture
def mock_groq_client():
- with patch('mem0.llms.groq.Groq') as mock_groq:
+ with patch("mem0.llms.groq.Groq") as mock_groq:
mock_client = Mock()
mock_groq.return_value = mock_client
yield mock_client
@@ -19,9 +19,9 @@ def test_generate_response_without_tools(mock_groq_client):
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
-
+
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_groq_client.chat.completions.create.return_value = mock_response
@@ -29,11 +29,7 @@ def test_generate_response_without_tools(mock_groq_client):
response = llm.generate_response(messages)
mock_groq_client.chat.completions.create.assert_called_once_with(
- model="llama3-70b-8192",
- messages=messages,
- temperature=0.7,
- max_tokens=100,
- top_p=1.0
+ model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
@@ -43,7 +39,7 @@ def test_generate_response_with_tools(mock_groq_client):
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
@@ -53,23 +49,21 @@ def test_generate_response_with_tools(mock_groq_client):
"description": "Add a memory",
"parameters": {
"type": "object",
- "properties": {
- "data": {"type": "string", "description": "Data to add to memory"}
- },
+ "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
-
+
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
-
+
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
-
+
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_groq_client.chat.completions.create.return_value = mock_response
@@ -83,11 +77,10 @@ def test_generate_response_with_tools(mock_groq_client):
max_tokens=100,
top_p=1.0,
tools=tools,
- tool_choice="auto"
+ tool_choice="auto",
)
-
+
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
- assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
-
\ No newline at end of file
+ assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
diff --git a/tests/llms/test_litellm.py b/tests/llms/test_litellm.py
index f4b265aa..d7be93c9 100644
--- a/tests/llms/test_litellm.py
+++ b/tests/llms/test_litellm.py
@@ -8,14 +8,15 @@ from mem0.llms import litellm
@pytest.fixture
def mock_litellm():
- with patch('mem0.llms.litellm.litellm') as mock_litellm:
+ with patch("mem0.llms.litellm.litellm") as mock_litellm:
yield mock_litellm
+
def test_generate_response_with_unsupported_model(mock_litellm):
config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [{"role": "user", "content": "Hello"}]
-
+
mock_litellm.supports_function_calling.return_value = False
with pytest.raises(ValueError, match="Model 'unsupported-model' in litellm does not support function calling."):
@@ -27,9 +28,9 @@ def test_generate_response_without_tools(mock_litellm):
llm = litellm.LiteLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
-
+
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_litellm.completion.return_value = mock_response
@@ -38,11 +39,7 @@ def test_generate_response_without_tools(mock_litellm):
response = llm.generate_response(messages)
mock_litellm.completion.assert_called_once_with(
- model="gpt-4o",
- messages=messages,
- temperature=0.7,
- max_tokens=100,
- top_p=1.0
+ model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
@@ -52,7 +49,7 @@ def test_generate_response_with_tools(mock_litellm):
llm = litellm.LiteLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
@@ -62,23 +59,21 @@ def test_generate_response_with_tools(mock_litellm):
"description": "Add a memory",
"parameters": {
"type": "object",
- "properties": {
- "data": {"type": "string", "description": "Data to add to memory"}
- },
+ "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
-
+
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
-
+
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
-
+
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_litellm.completion.return_value = mock_response
@@ -87,16 +82,10 @@ def test_generate_response_with_tools(mock_litellm):
response = llm.generate_response(messages, tools=tools)
mock_litellm.completion.assert_called_once_with(
- model="gpt-4o",
- messages=messages,
- temperature=0.7,
- max_tokens=100,
- top_p=1,
- tools=tools,
- tool_choice="auto"
+ model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto"
)
-
+
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
- assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
+ assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py
index d99fd2bc..f8158335 100644
--- a/tests/llms/test_ollama.py
+++ b/tests/llms/test_ollama.py
@@ -9,61 +9,48 @@ from mem0.llms.utils.tools import ADD_MEMORY_TOOL
@pytest.fixture
def mock_ollama_client():
- with patch('mem0.llms.ollama.Client') as mock_ollama:
+ with patch("mem0.llms.ollama.Client") as mock_ollama:
mock_client = Mock()
mock_client.list.return_value = {"models": [{"name": "llama3.1:70b"}]}
mock_ollama.return_value = mock_client
yield mock_client
+
def test_generate_response_without_tools(mock_ollama_client):
config = BaseLlmConfig(model="llama3.1:70b", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OllamaLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
-
- mock_response = {
- 'message': {"content": "I'm doing well, thank you for asking!"}
- }
+
+ mock_response = {"message": {"content": "I'm doing well, thank you for asking!"}}
mock_ollama_client.chat.return_value = mock_response
response = llm.generate_response(messages)
mock_ollama_client.chat.assert_called_once_with(
- model="llama3.1:70b",
- messages=messages,
- options={
- "temperature": 0.7,
- "num_predict": 100,
- "top_p": 1.0
- }
+ model="llama3.1:70b", messages=messages, options={"temperature": 0.7, "num_predict": 100, "top_p": 1.0}
)
assert response == "I'm doing well, thank you for asking!"
+
def test_generate_response_with_tools(mock_ollama_client):
config = BaseLlmConfig(model="llama3.1:70b", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OllamaLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [ADD_MEMORY_TOOL]
-
+
mock_response = {
- 'message': {
+ "message": {
"content": "I've added the memory for you.",
- "tool_calls": [
- {
- "function": {
- "name": "add_memory",
- "arguments": {"data": "Today is a sunny day."}
- }
- }
- ]
+ "tool_calls": [{"function": {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}}}],
}
}
-
+
mock_ollama_client.chat.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
@@ -71,16 +58,11 @@ def test_generate_response_with_tools(mock_ollama_client):
mock_ollama_client.chat.assert_called_once_with(
model="llama3.1:70b",
messages=messages,
- options={
- "temperature": 0.7,
- "num_predict": 100,
- "top_p": 1.0
- },
- tools=tools
+ options={"temperature": 0.7, "num_predict": 100, "top_p": 1.0},
+ tools=tools,
)
-
+
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
- assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
-
\ No newline at end of file
+ assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py
index 204487c4..be2f6f95 100644
--- a/tests/llms/test_openai.py
+++ b/tests/llms/test_openai.py
@@ -8,7 +8,7 @@ from mem0.llms.openai import OpenAILLM
@pytest.fixture
def mock_openai_client():
- with patch('mem0.llms.openai.OpenAI') as mock_openai:
+ with patch("mem0.llms.openai.OpenAI") as mock_openai:
mock_client = Mock()
mock_openai.return_value = mock_client
yield mock_client
@@ -19,9 +19,9 @@ def test_generate_response_without_tools(mock_openai_client):
llm = OpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
-
+
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_openai_client.chat.completions.create.return_value = mock_response
@@ -29,11 +29,7 @@ def test_generate_response_without_tools(mock_openai_client):
response = llm.generate_response(messages)
mock_openai_client.chat.completions.create.assert_called_once_with(
- model="gpt-4o",
- messages=messages,
- temperature=0.7,
- max_tokens=100,
- top_p=1.0
+ model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
@@ -43,7 +39,7 @@ def test_generate_response_with_tools(mock_openai_client):
llm = OpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
@@ -53,23 +49,21 @@ def test_generate_response_with_tools(mock_openai_client):
"description": "Add a memory",
"parameters": {
"type": "object",
- "properties": {
- "data": {"type": "string", "description": "Data to add to memory"}
- },
+ "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
-
+
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
-
+
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
-
+
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response
@@ -77,17 +71,10 @@ def test_generate_response_with_tools(mock_openai_client):
response = llm.generate_response(messages, tools=tools)
mock_openai_client.chat.completions.create.assert_called_once_with(
- model="gpt-4o",
- messages=messages,
- temperature=0.7,
- max_tokens=100,
- top_p=1.0,
- tools=tools,
- tool_choice="auto"
+ model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
)
-
+
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
- assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
-
\ No newline at end of file
+ assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
diff --git a/tests/llms/test_together.py b/tests/llms/test_together.py
index f317d106..7c59ee41 100644
--- a/tests/llms/test_together.py
+++ b/tests/llms/test_together.py
@@ -8,7 +8,7 @@ from mem0.llms.together import TogetherLLM
@pytest.fixture
def mock_together_client():
- with patch('mem0.llms.together.Together') as mock_together:
+ with patch("mem0.llms.together.Together") as mock_together:
mock_client = Mock()
mock_together.return_value = mock_client
yield mock_client
@@ -19,9 +19,9 @@ def test_generate_response_without_tools(mock_together_client):
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
-
+
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_together_client.chat.completions.create.return_value = mock_response
@@ -29,11 +29,7 @@ def test_generate_response_without_tools(mock_together_client):
response = llm.generate_response(messages)
mock_together_client.chat.completions.create.assert_called_once_with(
- model="mistralai/Mixtral-8x7B-Instruct-v0.1",
- messages=messages,
- temperature=0.7,
- max_tokens=100,
- top_p=1.0
+ model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
@@ -43,7 +39,7 @@ def test_generate_response_with_tools(mock_together_client):
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+ {"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
@@ -53,23 +49,21 @@ def test_generate_response_with_tools(mock_together_client):
"description": "Add a memory",
"parameters": {
"type": "object",
- "properties": {
- "data": {"type": "string", "description": "Data to add to memory"}
- },
+ "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
-
+
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
-
+
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
-
+
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_together_client.chat.completions.create.return_value = mock_response
@@ -83,11 +77,10 @@ def test_generate_response_with_tools(mock_together_client):
max_tokens=100,
top_p=1.0,
tools=tools,
- tool_choice="auto"
+ tool_choice="auto",
)
-
+
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
- assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
-
\ No newline at end of file
+ assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
diff --git a/tests/test_main.py b/tests/test_main.py
index 16a672e3..8ed22245 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -4,42 +4,39 @@ from unittest.mock import Mock, patch
from mem0.memory.main import Memory
from mem0.configs.base import MemoryConfig
+
@pytest.fixture(autouse=True)
def mock_openai():
- os.environ['OPENAI_API_KEY'] = "123"
- with patch('openai.OpenAI') as mock:
+ os.environ["OPENAI_API_KEY"] = "123"
+ with patch("openai.OpenAI") as mock:
mock.return_value = Mock()
yield mock
+
@pytest.fixture
def memory_instance():
- with patch('mem0.utils.factory.EmbedderFactory') as mock_embedder, \
- patch('mem0.utils.factory.VectorStoreFactory') as mock_vector_store, \
- patch('mem0.utils.factory.LlmFactory') as mock_llm, \
- patch('mem0.memory.telemetry.capture_event'), \
- patch('mem0.memory.graph_memory.MemoryGraph'):
+ with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
+ "mem0.utils.factory.VectorStoreFactory"
+ ) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
+ "mem0.memory.telemetry.capture_event"
+ ), patch("mem0.memory.graph_memory.MemoryGraph"):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
-
+
config = MemoryConfig(version="v1.1")
config.graph_store.config = {"some_config": "value"}
return Memory(config)
-@pytest.mark.parametrize("version, enable_graph", [
- ("v1.0", False),
- ("v1.1", True)
-])
+
+@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_add(memory_instance, version, enable_graph):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
memory_instance._add_to_graph = Mock(return_value=[])
- result = memory_instance.add(
- messages=[{"role": "user", "content": "Test message"}],
- user_id="test_user"
- )
+ result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user")
assert "results" in result
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
@@ -47,26 +44,27 @@ def test_add(memory_instance, version, enable_graph):
assert result["relations"] == []
memory_instance._add_to_vector_store.assert_called_once_with(
- [{"role": "user", "content": "Test message"}],
- {"user_id": "test_user"},
- {"user_id": "test_user"}
- )
-
- # Remove the conditional assertion for _add_to_graph
- memory_instance._add_to_graph.assert_called_once_with(
- [{"role": "user", "content": "Test message"}],
- {"user_id": "test_user"}
+ [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
)
+ # Remove the conditional assertion for _add_to_graph
+ memory_instance._add_to_graph.assert_called_once_with(
+ [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}
+ )
+
+
def test_get(memory_instance):
- mock_memory = Mock(id="test_id", payload={
- "data": "Test memory",
- "user_id": "test_user",
- "hash": "test_hash",
- "created_at": "2023-01-01T00:00:00",
- "updated_at": "2023-01-02T00:00:00",
- "extra_field": "extra_value"
- })
+ mock_memory = Mock(
+ id="test_id",
+ payload={
+ "data": "Test memory",
+ "user_id": "test_user",
+ "hash": "test_hash",
+ "created_at": "2023-01-01T00:00:00",
+ "updated_at": "2023-01-02T00:00:00",
+ "extra_field": "extra_value",
+ },
+ )
memory_instance.vector_store.get = Mock(return_value=mock_memory)
result = memory_instance.get("test_id")
@@ -79,16 +77,14 @@ def test_get(memory_instance):
assert result["updated_at"] == "2023-01-02T00:00:00"
assert result["metadata"] == {"extra_field": "extra_value"}
-@pytest.mark.parametrize("version, enable_graph", [
- ("v1.0", False),
- ("v1.1", True)
-])
+
+@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_search(memory_instance, version, enable_graph):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
mock_memories = [
Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"}, score=0.9),
- Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8)
+ Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8),
]
memory_instance.vector_store.search = Mock(return_value=mock_memories)
memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
@@ -118,17 +114,16 @@ def test_search(memory_instance, version, enable_graph):
assert result["results"][0]["score"] == 0.9
memory_instance.vector_store.search.assert_called_once_with(
- query=[0.1, 0.2, 0.3],
- limit=100,
- filters={"user_id": "test_user"}
+ query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
)
memory_instance.embedding_model.embed.assert_called_once_with("test query")
-
+
if enable_graph:
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
else:
memory_instance.graph.search.assert_not_called()
+
def test_update(memory_instance):
memory_instance._update_memory = Mock()
@@ -137,6 +132,7 @@ def test_update(memory_instance):
memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory")
assert result["message"] == "Memory updated successfully!"
+
def test_delete(memory_instance):
memory_instance._delete_memory = Mock()
@@ -145,10 +141,8 @@ def test_delete(memory_instance):
memory_instance._delete_memory.assert_called_once_with("test_id")
assert result["message"] == "Memory deleted successfully!"
-@pytest.mark.parametrize("version, enable_graph", [
- ("v1.0", False),
- ("v1.1", True)
-])
+
+@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_delete_all(memory_instance, version, enable_graph):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
@@ -160,14 +154,15 @@ def test_delete_all(memory_instance, version, enable_graph):
result = memory_instance.delete_all(user_id="test_user")
assert memory_instance._delete_memory.call_count == 2
-
+
if enable_graph:
memory_instance.graph.delete_all.assert_called_once_with({"user_id": "test_user"})
else:
memory_instance.graph.delete_all.assert_not_called()
-
+
assert result["message"] == "Memories deleted successfully!"
+
def test_reset(memory_instance):
memory_instance.vector_store.delete_col = Mock()
memory_instance.db.reset = Mock()
@@ -177,22 +172,30 @@ def test_reset(memory_instance):
memory_instance.vector_store.delete_col.assert_called_once()
memory_instance.db.reset.assert_called_once()
-@pytest.mark.parametrize("version, enable_graph, expected_result", [
- ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
- ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
- ("v1.1", True, {
- "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
- "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}]
- })
-])
+
+@pytest.mark.parametrize(
+ "version, enable_graph, expected_result",
+ [
+ ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
+ ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
+ (
+ "v1.1",
+ True,
+ {
+ "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
+ "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}],
+ },
+ ),
+ ],
+)
def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
mock_memories = [Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"})]
memory_instance.vector_store.list = Mock(return_value=(mock_memories, None))
- memory_instance.graph.get_all = Mock(return_value=[
- {"source": "entity1", "relationship": "rel", "target": "entity2"}
- ])
+ memory_instance.graph.get_all = Mock(
+ return_value=[{"source": "entity1", "relationship": "rel", "target": "entity2"}]
+ )
result = memory_instance.get_all(user_id="test_user")
@@ -204,7 +207,7 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
assert result_item["id"] == expected_item["id"]
assert result_item["memory"] == expected_item["memory"]
assert result_item["user_id"] == expected_item["user_id"]
-
+
if enable_graph:
assert "relations" in result
assert result["relations"] == expected_result["relations"]
@@ -212,7 +215,7 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
assert "relations" not in result
memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
-
+
if enable_graph:
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
else:
diff --git a/tests/test_memory.py b/tests/test_memory.py
index 9c1c6003..2659d06c 100644
--- a/tests/test_memory.py
+++ b/tests/test_memory.py
@@ -7,6 +7,7 @@ from mem0 import Memory
def memory_store():
return Memory()
+
@pytest.mark.skip(reason="Not implemented")
def test_create_memory(memory_store):
data = "Name is John Doe."
diff --git a/tests/test_proxy.py b/tests/test_proxy.py
index 8e7e58ec..8088f380 100644
--- a/tests/test_proxy.py
+++ b/tests/test_proxy.py
@@ -11,23 +11,26 @@ from mem0.proxy.main import Chat, Completions, Mem0
def mock_memory_client():
return Mock(spec=MemoryClient)
+
@pytest.fixture
def mock_openai_embedding_client():
- with patch('mem0.embeddings.openai.OpenAI') as mock_openai:
+ with patch("mem0.embeddings.openai.OpenAI") as mock_openai:
mock_client = Mock()
mock_openai.return_value = mock_client
yield mock_client
+
@pytest.fixture
def mock_openai_llm_client():
- with patch('mem0.llms.openai.OpenAI') as mock_openai:
+ with patch("mem0.llms.openai.OpenAI") as mock_openai:
mock_client = Mock()
mock_openai.return_value = mock_client
yield mock_client
+
@pytest.fixture
def mock_litellm():
- with patch('mem0.proxy.main.litellm') as mock:
+ with patch("mem0.proxy.main.litellm") as mock:
yield mock
@@ -39,16 +42,16 @@ def test_mem0_initialization_with_api_key(mock_openai_embedding_client, mock_ope
def test_mem0_initialization_with_config():
config = {"some_config": "value"}
- with patch('mem0.Memory.from_config') as mock_from_config:
+ with patch("mem0.Memory.from_config") as mock_from_config:
mem0 = Mem0(config=config)
mock_from_config.assert_called_once_with(config)
assert isinstance(mem0.chat, Chat)
def test_mem0_initialization_without_params(mock_openai_embedding_client, mock_openai_llm_client):
- mem0 = Mem0()
- assert isinstance(mem0.mem0_client, Memory)
- assert isinstance(mem0.chat, Chat)
+ mem0 = Mem0()
+ assert isinstance(mem0.mem0_client, Memory)
+ assert isinstance(mem0.chat, Chat)
def test_chat_initialization(mock_memory_client):
@@ -58,48 +61,37 @@ def test_chat_initialization(mock_memory_client):
def test_completions_create(mock_memory_client, mock_litellm):
completions = Completions(mock_memory_client)
-
- messages = [
- {"role": "user", "content": "Hello, how are you?"}
- ]
+
+ messages = [{"role": "user", "content": "Hello, how are you?"}]
mock_memory_client.search.return_value = [{"memory": "Some relevant memory"}]
mock_litellm.completion.return_value = {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]}
-
- response = completions.create(
- model="gpt-4o-mini",
- messages=messages,
- user_id="test_user",
- temperature=0.7
- )
-
+
+ response = completions.create(model="gpt-4o-mini", messages=messages, user_id="test_user", temperature=0.7)
+
mock_memory_client.add.assert_called_once()
mock_memory_client.search.assert_called_once()
-
+
mock_litellm.completion.assert_called_once()
call_args = mock_litellm.completion.call_args[1]
- assert call_args['model'] == "gpt-4o-mini"
- assert len(call_args['messages']) == 2
- assert call_args['temperature'] == 0.7
-
+ assert call_args["model"] == "gpt-4o-mini"
+ assert len(call_args["messages"]) == 2
+ assert call_args["temperature"] == 0.7
+
assert response == {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]}
def test_completions_create_with_system_message(mock_memory_client, mock_litellm):
completions = Completions(mock_memory_client)
-
+
messages = [
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"}
+ {"role": "user", "content": "Hello, how are you?"},
]
mock_memory_client.search.return_value = [{"memory": "Some relevant memory"}]
mock_litellm.completion.return_value = {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]}
-
- completions.create(
- model="gpt-4o-mini",
- messages=messages,
- user_id="test_user"
- )
-
+
+ completions.create(model="gpt-4o-mini", messages=messages, user_id="test_user")
+
call_args = mock_litellm.completion.call_args[1]
- assert call_args['messages'][0]['role'] == "system"
- assert call_args['messages'][0]['content'] == MEMORY_ANSWER_PROMPT
+ assert call_args["messages"][0]["role"] == "system"
+ assert call_args["messages"][0]["content"] == MEMORY_ANSWER_PROMPT
diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py
index aa36f2f0..1d0b100f 100644
--- a/tests/test_telemetry.py
+++ b/tests/test_telemetry.py
@@ -7,23 +7,28 @@ MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
if isinstance(MEM0_TELEMETRY, str):
MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
+
def use_telemetry():
- if os.getenv('MEM0_TELEMETRY', "true").lower() == "true":
+ if os.getenv("MEM0_TELEMETRY", "true").lower() == "true":
return True
return False
+
@pytest.fixture(autouse=True)
def reset_env():
with patch.dict(os.environ, {}, clear=True):
yield
+
def test_telemetry_enabled():
- with patch.dict(os.environ, {'MEM0_TELEMETRY': "true"}):
+ with patch.dict(os.environ, {"MEM0_TELEMETRY": "true"}):
assert use_telemetry() is True
+
def test_telemetry_disabled():
- with patch.dict(os.environ, {'MEM0_TELEMETRY': "false"}):
+ with patch.dict(os.environ, {"MEM0_TELEMETRY": "false"}):
assert use_telemetry() is False
+
def test_telemetry_default_enabled():
assert use_telemetry() is True