Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
5
mem0/__init__.py
Normal file
5
mem0/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import importlib.metadata
|
||||
|
||||
__version__ = importlib.metadata.version(__package__ or __name__)
|
||||
|
||||
from mem0.memory.main import Memory # noqa
|
||||
0
mem0/configs/__init__.py
Normal file
0
mem0/configs/__init__.py
Normal file
0
mem0/configs/base.py
Normal file
0
mem0/configs/base.py
Normal file
17
mem0/configs/prompts.py
Normal file
17
mem0/configs/prompts.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Prompt template for the memory-merge step. It is filled in with ``str.format``,
# so the only placeholders are ``{existing_memories}`` and ``{memory}``.
UPDATE_MEMORY_PROMPT = """
You are an expert at merging, updating, and organizing user memories. When provided with existing memories and new information, your task is to merge and update the memory list to reflect the most accurate and current information. You are also provided with the matching score for each existing memory to the new information. Make sure to leverage this information to make informed decisions about which memories to update or merge.

Guidelines:
- Eliminate duplicate memories and merge related memories to ensure a concise and updated list.
- If a memory is directly contradicted by new information, critically evaluate both pieces of information:
    - If the new memory provides a more recent or accurate update, replace the old memory with new one.
    - If the new memory seems inaccurate or less detailed, retain the original and discard the new one.
- Maintain a consistent and clear style throughout all memories, ensuring each entry is concise yet informative.
- If the new memory is a variation or extension of an existing memory, update the existing memory to reflect the new information.

Here are the details of the task:
- Existing Memories:
{existing_memories}

- New Memory: {memory}
"""
|
||||
0
mem0/embeddings/__init__.py
Normal file
0
mem0/embeddings/__init__.py
Normal file
16
mem0/embeddings/base.py
Normal file
16
mem0/embeddings/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class EmbeddingBase(ABC):
    """Abstract interface that every embedding backend has to implement."""

    @abstractmethod
    def embed(self, text):
        """
        Turn a piece of text into its embedding vector.

        Args:
            text (str): Text to be embedded.

        Returns:
            list: Embedding vector computed for ``text``.
        """
        ...
|
||||
19
mem0/embeddings/huggingface.py
Normal file
19
mem0/embeddings/huggingface.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sentence_transformers import SentenceTransformer

from mem0.embeddings.base import EmbeddingBase


class HuggingFaceEmbedding(EmbeddingBase):
    """Embedding backend that runs a local ``sentence-transformers`` model."""

    def __init__(self, model_name="multi-qa-MiniLM-L6-cos-v1"):
        """
        Load the sentence-transformers model.

        Args:
            model_name (str): Model name or path handed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)
        # The other embedders expose ``dims``; ask the loaded model for its
        # output size instead of hard-coding a number per model name.
        self.dims = self.model.get_sentence_embedding_dimension()

    def embed(self, text):
        """
        Get the embedding for the given text using Hugging Face.

        Implements the abstract ``EmbeddingBase.embed``; without it the class
        could not be instantiated.

        Args:
            text (str): The text to embed.

        Returns:
            The embedding vector exactly as returned by ``SentenceTransformer.encode``.
        """
        return self.model.encode(text)

    # Backward-compatible alias for the method name this class originally used.
    get_embedding = embed
|
||||
30
mem0/embeddings/ollama.py
Normal file
30
mem0/embeddings/ollama.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import ollama

from mem0.embeddings.base import EmbeddingBase


class OllamaEmbedding(EmbeddingBase):
    """Embedding backend that uses a model served by a local Ollama instance."""

    def __init__(self, model="nomic-embed-text", dims=512):
        """
        Remember the model name and make sure the model is available locally.

        Args:
            model (str): Name of the Ollama embedding model.
            dims (int): Vector size advertised through ``self.dims``.
                NOTE(review): 512 is kept as the historical default; confirm it
                matches the real output size of the chosen model, because the
                vector-store collection is sized from this value.
        """
        self.model = model
        self._ensure_model_exists()
        self.dims = dims

    def _ensure_model_exists(self):
        """
        Ensure the specified model exists locally. If not, pull it from Ollama.
        """
        model_list = [m["name"] for m in ollama.list()["models"]]
        # Prefix match so that a tagged local name (e.g. "<model>:latest") counts.
        if not any(m.startswith(self.model) for m in model_list):
            ollama.pull(self.model)

    def embed(self, text):
        """
        Get the embedding for the given text using Ollama.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector.
        """
        response = ollama.embeddings(model=self.model, prompt=text)
        return response["embedding"]
|
||||
27
mem0/embeddings/openai.py
Normal file
27
mem0/embeddings/openai.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class OpenAIEmbedding(EmbeddingBase):
    """Embedding backend that calls the OpenAI embeddings endpoint."""

    def __init__(self, model="text-embedding-3-small"):
        """
        Create the OpenAI client and remember which model to call.

        Args:
            model (str): Name of the OpenAI embedding model.
        """
        self.client = OpenAI()
        self.model = model
        # Vector size advertised to callers; a fixed value, independent of ``model``.
        self.dims = 1536

    def embed(self, text):
        """
        Embed ``text`` with the configured OpenAI model.

        Newlines are replaced by spaces before the request is sent.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector of the single submitted input.
        """
        single_line = text.replace("\n", " ")
        result = self.client.embeddings.create(input=[single_line], model=self.model)
        first_item = result.data[0]
        return first_item.embedding
|
||||
0
mem0/llms/__init__.py
Normal file
0
mem0/llms/__init__.py
Normal file
16
mem0/llms/base.py
Normal file
16
mem0/llms/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LLMBase(ABC):
    """Abstract interface that every LLM backend has to implement."""

    @abstractmethod
    def generate_response(self, messages):
        """
        Produce a model response for a chat transcript.

        Args:
            messages (list): Message dicts, each carrying 'role' and 'content'.

        Returns:
            str: The generated response.
        """
        ...
|
||||
29
mem0/llms/ollama.py
Normal file
29
mem0/llms/ollama.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import ollama

from mem0.llms.base import LLMBase


class OllamaLLM(LLMBase):
    """LLM backend that chats with a model served by a local Ollama instance."""

    def __init__(self, model="llama3"):
        """
        Remember the model name and make sure the model is available locally.

        Args:
            model (str): Name of the Ollama chat model.
        """
        self.model = model
        self._ensure_model_exists()

    def _ensure_model_exists(self):
        """
        Ensure the specified model exists locally. If not, pull it from Ollama.
        """
        model_list = [m["name"] for m in ollama.list()["models"]]
        # Prefix match so that a tagged local name (e.g. "<model>:latest") counts.
        if not any(m.startswith(self.model) for m in model_list):
            ollama.pull(self.model)

    def generate_response(self, messages):
        """
        Generate a response based on the given messages using Ollama.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.

        Returns:
            str: The generated response.
        """
        response = ollama.chat(model=self.model, messages=messages)
        return response["message"]["content"]
|
||||
41
mem0/llms/openai.py
Normal file
41
mem0/llms/openai.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class OpenAILLM(LLMBase):
    """LLM backend that talks to the OpenAI chat-completions API."""

    def __init__(self, model="gpt-4o"):
        """
        Create the OpenAI client and remember which model to call.

        Args:
            model (str): Name of the OpenAI chat model.
        """
        self.client = OpenAI()
        self.model = model

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Send a chat-completion request to OpenAI.

        Args:
            messages (list): Message dicts, each carrying 'role' and 'content'.
            response_format (optional): Forwarded as ``response_format`` only
                when truthy. Defaults to None.
            tools (list, optional): Tool definitions the model may call. Defaults to None.
            tool_choice (str, optional): Tool choice method; only sent together
                with ``tools``. Defaults to "auto".

        Returns:
            The object returned by ``client.chat.completions.create``, unmodified.
        """
        request = {"model": self.model, "messages": messages}
        if response_format:
            request["response_format"] = response_format
        if tools:
            request.update(tools=tools, tool_choice=tool_choice)
        return self.client.chat.completions.create(**request)
|
||||
0
mem0/llms/utils/__init__.py
Normal file
0
mem0/llms/utils/__init__.py
Normal file
0
mem0/llms/utils/functions.py
Normal file
0
mem0/llms/utils/functions.py
Normal file
54
mem0/llms/utils/tools.py
Normal file
54
mem0/llms/utils/tools.py
Normal file
@@ -0,0 +1,54 @@
|
||||
def _function_tool(name, description, properties, required):
    """Wrap a parameter schema in the function-tool envelope shared by all tools."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool the model calls to store a brand-new memory.
ADD_MEMORY_TOOL = _function_tool(
    "add_memory",
    "Add a memory",
    {"data": {"type": "string", "description": "Data to add to memory"}},
    ["data"],
)

# Tool the model calls to overwrite the text of an existing memory.
UPDATE_MEMORY_TOOL = _function_tool(
    "update_memory",
    "Update memory provided ID and data",
    {
        "memory_id": {
            "type": "string",
            "description": "memory_id of the memory to update",
        },
        "data": {
            "type": "string",
            "description": "Updated data for the memory",
        },
    },
    ["memory_id", "data"],
)

# Tool the model calls to remove an existing memory.
DELETE_MEMORY_TOOL = _function_tool(
    "delete_memory",
    "Delete memory by memory_id",
    {
        "memory_id": {
            "type": "string",
            "description": "memory_id of the memory to delete",
        }
    },
    ["memory_id"],
)
|
||||
0
mem0/memory/__init__.py
Normal file
0
mem0/memory/__init__.py
Normal file
63
mem0/memory/base.py
Normal file
63
mem0/memory/base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class MemoryBase(ABC):
    """Abstract interface for a memory store: read, list, update, delete, history."""

    @abstractmethod
    def get(self, memory_id):
        """
        Fetch a single memory.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory.
        """
        ...

    @abstractmethod
    def get_all(self):
        """
        Enumerate the stored memories.

        Returns:
            list: List of all memories.
        """
        ...

    @abstractmethod
    def update(self, memory_id, data):
        """
        Change the content of an existing memory.

        Args:
            memory_id (str): ID of the memory to update.
            data (dict): Data to update the memory with.

        Returns:
            dict: Updated memory.
        """
        ...

    @abstractmethod
    def delete(self, memory_id):
        """
        Remove a memory.

        Args:
            memory_id (str): ID of the memory to delete.
        """
        ...

    @abstractmethod
    def history(self, memory_id):
        """
        Return the recorded changes of a memory.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
        """
        ...
|
||||
401
mem0/memory/main.py
Normal file
401
mem0/memory/main.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from mem0.embeddings.openai import OpenAIEmbedding
|
||||
from mem0.llms.openai import OpenAILLM
|
||||
from mem0.llms.utils.tools import (
|
||||
ADD_MEMORY_TOOL,
|
||||
DELETE_MEMORY_TOOL,
|
||||
UPDATE_MEMORY_TOOL,
|
||||
)
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import mem0_dir, setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import get_update_memory_messages
|
||||
from mem0.vector_stores.configs import VectorStoreConfig
|
||||
from mem0.vector_stores.qdrant import Qdrant
|
||||
|
||||
# Setup user config
|
||||
setup_config()
|
||||
|
||||
|
||||
class MemoryItem(BaseModel):
    """Validated view of one stored memory: id, text, metadata and optional score."""

    # Required fields.
    id: str = Field(default=..., description="The unique identifier for the text data")
    text: str = Field(default=..., description="The text content")
    # Metadata values are typed as Any, so they are not restricted to strings.
    metadata: Dict[str, Any] = Field(
        description="Additional metadata for the text data", default_factory=dict
    )
    # Left as None when no score was supplied for the item.
    score: Optional[float] = Field(
        default=None, description="The score associated with the text data"
    )
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
    """Settings for a memory store: vector store, history database and collection."""

    vector_store: VectorStoreConfig = Field(
        default_factory=VectorStoreConfig,
        description="Configuration for the vector store",
    )
    # Defaults to a file inside the mem0 directory.
    history_db_path: str = Field(
        default=os.path.join(mem0_dir, "history.db"),
        description="Path to the history database",
    )
    collection_name: str = Field(description="Name of the collection", default="mem0")
    embedding_model_dims: int = Field(
        description="Dimensions of the embedding model", default=1536
    )
|
||||
|
||||
|
||||
class Memory(MemoryBase):
    """
    Long-term memory store built from an embedder, a vector store, an LLM and a history log.

    Text passed to ``add`` is compared with the closest stored memories; the LLM
    then decides, through tool calls, whether to add, update or delete entries.
    Every change is also recorded in the SQLite history database.
    """

    def __init__(self, config: Optional[MemoryConfig] = None):
        """
        Wire up the embedder, vector store, LLM and history database.

        Args:
            config (MemoryConfig, optional): Settings to use. A fresh default
                ``MemoryConfig`` is created when omitted.

        Raises:
            ValueError: If the configured vector store provider is not supported.
        """
        # Build the default per call rather than once at definition time.
        self.config = MemoryConfig() if config is None else config
        self.embedding_model = OpenAIEmbedding()
        # Initialize the appropriate vector store based on the configuration
        provider = self.config.vector_store.provider
        vector_store_config = self.config.vector_store.config
        if provider == "qdrant":
            self.vector_store = Qdrant(
                host=vector_store_config.host,
                port=vector_store_config.port,
                path=vector_store_config.path,
                url=vector_store_config.url,
                api_key=vector_store_config.api_key,
            )
        else:
            # MemoryConfig has no ``vector_store_type`` attribute; report the
            # provider value that was actually rejected.
            raise ValueError(f"Unsupported vector store type: {provider}")

        self.llm = OpenAILLM()
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.collection_name
        # One call is enough; the collection is sized from the embedder.
        self.vector_store.create_col(
            name=self.collection_name, vector_size=self.embedding_model.dims
        )
        capture_event("mem0.init", self)

    @classmethod
    def from_config(cls, config_dict: Dict[str, Any]):
        """
        Build a ``Memory`` from a plain dict of settings.

        Args:
            config_dict (dict): Keyword arguments for ``MemoryConfig``.

        Returns:
            Memory: Instance configured from ``config_dict``.

        Raises:
            ValidationError: Re-raised after logging when the dict is invalid.
        """
        try:
            config = MemoryConfig(**config_dict)
        except ValidationError as e:
            logging.error(f"Configuration validation error: {e}")
            raise
        return cls(config)

    def add(
        self,
        data,
        user_id=None,
        agent_id=None,
        run_id=None,
        metadata=None,
        filters=None,
    ):
        """
        Create a new memory.

        The five closest existing memories are handed to the LLM together with
        ``data``; each tool call the LLM returns is then executed.

        Args:
            data (str): Data to store in the memory.
            user_id (str, optional): ID of the user creating the memory. Defaults to None.
            agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
            run_id (str, optional): ID of the run creating the memory. Defaults to None.
            metadata (dict, optional): Metadata to store with the memory. Defaults to None.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: One dict per executed tool call with the keys ``id`` (memory ID),
            ``event`` ("add", "update" or "delete") and ``data``. Empty when the
            LLM made no tool call.
        """
        if metadata is None:
            logging.warning("Metadata not provided. Using empty metadata.")
            metadata = {}
        else:
            # Work on a copy so the caller's dict is not modified.
            metadata = dict(metadata)
        embeddings = self.embedding_model.embed(data)

        filters = dict(filters) if filters else {}
        if user_id:
            filters["user_id"] = metadata["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = metadata["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = metadata["run_id"] = run_id

        existing_memories = self.vector_store.search(
            name=self.collection_name,
            query=embeddings,
            limit=5,
            filters=filters,
        )
        existing_memories = [
            MemoryItem(
                id=mem.id,
                score=mem.score,
                metadata=mem.payload,
                text=mem.payload["data"],
            )
            for mem in existing_memories
        ]
        serialized_existing_memories = [
            item.model_dump(include={"id", "text", "score"})
            for item in existing_memories
        ]
        logging.info(f"Total existing memories: {len(existing_memories)}")
        messages = get_update_memory_messages(serialized_existing_memories, data)
        # Add tools for noop, add, update, delete memory.
        tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
        response = self.llm.generate_response(messages=messages, tools=tools)
        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls

        response = []
        if tool_calls:
            available_functions = {
                "add_memory": self._create_memory_tool,
                "update_memory": self._update_memory_tool,
                "delete_memory": self._delete_memory_tool,
            }
            for tool_call in tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                if function_to_call is None:
                    # The name comes from the model; skip anything we did not
                    # offer instead of aborting the remaining tool calls.
                    logging.warning(
                        f"[openai_func] ignoring unknown function: {function_name}"
                    )
                    continue
                function_args = json.loads(tool_call.function.arguments)
                logging.info(
                    f"[openai_func] func: {function_name}, args: {function_args}"
                )

                # Pass metadata to the function if it requires it
                if function_name in ["add_memory", "update_memory"]:
                    # Each stored payload gets its own dict; the tools add keys to it.
                    function_args["metadata"] = dict(metadata)

                function_result = function_to_call(**function_args)
                # Fetch the memory_id from the response
                response.append(
                    {
                        "id": function_result,
                        "event": function_name.replace("_memory", ""),
                        "data": function_args.get("data"),
                    }
                )
                capture_event(
                    "mem0.add.function_call",
                    self,
                    {"memory_id": function_result, "function_name": function_name},
                )
        capture_event("mem0.add", self)
        return response

    def get(self, memory_id):
        """
        Retrieve a memory by ID.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory (``id``, ``text``, ``metadata``), or None when
            the vector store returns nothing for the ID.
        """
        capture_event("mem0.get", self, {"memory_id": memory_id})
        memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id)
        if not memory:
            return None
        return MemoryItem(
            id=memory.id,
            metadata=memory.payload,
            text=memory.payload["data"],
        ).model_dump(exclude={"score"})

    def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
        List all memories.

        Args:
            user_id (str, optional): Only list memories of this user. Defaults to None.
            agent_id (str, optional): Only list memories of this agent. Defaults to None.
            run_id (str, optional): Only list memories of this run. Defaults to None.
            limit (int, optional): Maximum number of memories requested. Defaults to 100.

        Returns:
            list: List of memory dicts (``id``, ``text``, ``metadata``).
        """
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
        memories = self.vector_store.list(
            name=self.collection_name, filters=filters, limit=limit
        )
        # The records are the first element of what ``list`` returns.
        return [
            MemoryItem(
                id=mem.id,
                metadata=mem.payload,
                text=mem.payload["data"],
            ).model_dump(exclude={"score"})
            for mem in memories[0]
        ]

    def search(
        self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
    ):
        """
        Search for memories.

        Args:
            query (str): Query to search for.
            user_id (str, optional): ID of the user to search for. Defaults to None.
            agent_id (str, optional): ID of the agent to search for. Defaults to None.
            run_id (str, optional): ID of the run to search for. Defaults to None.
            limit (int, optional): Limit the number of results. Defaults to 100.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: List of search results (``id``, ``text``, ``metadata``, ``score``).
        """
        # Copy so the caller's filter dict is not modified.
        filters = dict(filters) if filters else {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
        embeddings = self.embedding_model.embed(query)
        memories = self.vector_store.search(
            name=self.collection_name, query=embeddings, limit=limit, filters=filters
        )
        return [
            MemoryItem(
                id=mem.id,
                metadata=mem.payload,
                score=mem.score,
                text=mem.payload["data"],
            ).model_dump()
            for mem in memories
        ]

    def update(self, memory_id, data):
        """
        Update a memory by ID.

        Args:
            memory_id (str): ID of the memory to update.
            data (str): New text for the memory.

        Returns:
            None
        """
        capture_event("mem0.update", self, {"memory_id": memory_id})
        self._update_memory_tool(memory_id, data)

    def delete(self, memory_id):
        """
        Delete a memory by ID.

        Args:
            memory_id (str): ID of the memory to delete.
        """
        capture_event("mem0.delete", self, {"memory_id": memory_id})
        self._delete_memory_tool(memory_id)

    def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
        Delete all memories matching the given IDs.

        Args:
            user_id (str, optional): ID of the user to delete memories for. Defaults to None.
            agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
            run_id (str, optional): ID of the run to delete memories for. Defaults to None.

        Raises:
            ValueError: If no filter is given; use ``reset()`` to wipe everything.
        """
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        if not filters:
            raise ValueError(
                "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
            )

        capture_event("mem0.delete_all", self, {"filters": len(filters)})
        # NOTE(review): no ``limit`` is passed, so only as many memories as the
        # vector store lists by default are deleted per call; confirm that default.
        memories = self.vector_store.list(name=self.collection_name, filters=filters)[0]
        for memory in memories:
            self._delete_memory_tool(memory.id)

    def history(self, memory_id):
        """
        Get the history of changes for a memory by ID.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
        """
        capture_event("mem0.history", self, {"memory_id": memory_id})
        return self.db.get_history(memory_id)

    def _create_memory_tool(self, data, metadata=None):
        """
        Embed ``data``, insert it as a new memory and record an "add" history entry.

        Args:
            data (str): Text of the memory.
            metadata (dict, optional): Extra payload stored next to the text.

        Returns:
            str: ID generated for the new memory.
        """
        logging.info(f"Creating memory with {data=}")
        embeddings = self.embedding_model.embed(data)
        memory_id = str(uuid.uuid4())
        # Copy so the caller's dict does not receive the bookkeeping keys.
        metadata = dict(metadata) if metadata else {}
        metadata["data"] = data
        metadata["created_at"] = int(time.time())

        self.vector_store.insert(
            name=self.collection_name,
            vectors=[embeddings],
            ids=[memory_id],
            payloads=[metadata],
        )
        self.db.add_history(memory_id, None, data, "add")
        return memory_id

    def _update_memory_tool(self, memory_id, data, metadata=None):
        """
        Re-embed ``data``, overwrite the memory and record an "update" history entry.

        Args:
            memory_id (str): ID of the memory to update.
            data (str): New text of the memory.
            metadata (dict, optional): Payload keys to set or override.

        Returns:
            str: ``memory_id``, so callers can report which memory changed.
        """
        existing_memory = self.vector_store.get(
            name=self.collection_name, vector_id=memory_id
        )
        prev_value = existing_memory.payload.get("data")

        # Start from the stored payload so keys such as user_id / created_at
        # survive the update, then apply the new values on top.
        new_metadata = dict(existing_memory.payload)
        new_metadata.update(metadata or {})
        new_metadata["data"] = data
        new_metadata["updated_at"] = int(time.time())
        embeddings = self.embedding_model.embed(data)
        self.vector_store.update(
            name=self.collection_name,
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )
        logging.info(f"Updating memory with ID {memory_id=} with {data=}")
        self.db.add_history(memory_id, prev_value, data, "update")
        return memory_id

    def _delete_memory_tool(self, memory_id):
        """
        Delete the memory and record a "delete" history entry.

        Args:
            memory_id (str): ID of the memory to delete.

        Returns:
            str: ``memory_id``, so callers can report which memory was removed.
        """
        logging.info(f"Deleting memory with {memory_id=}")
        existing_memory = self.vector_store.get(
            name=self.collection_name, vector_id=memory_id
        )
        prev_value = existing_memory.payload["data"]
        self.vector_store.delete(name=self.collection_name, vector_id=memory_id)
        self.db.add_history(memory_id, prev_value, None, "delete", is_deleted=1)
        return memory_id

    def reset(self):
        """
        Reset the memory store.

        Drops the collection and the history, then recreates an empty collection
        so this instance can still be used afterwards.
        """
        self.vector_store.delete_col(name=self.collection_name)
        self.db.reset()
        self.vector_store.create_col(
            name=self.collection_name, vector_size=self.embedding_model.dims
        )
        capture_event("mem0.reset", self)

    def chat(self, query):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Chat function not implemented yet.")
|
||||
28
mem0/memory/setup.py
Normal file
28
mem0/memory/setup.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
# Resolve the per-user mem0 directory and create it when it is missing.
home_dir = os.path.expanduser("~")
mem0_dir = os.path.join(home_dir, ".mem0")
os.makedirs(mem0_dir, exist_ok=True)


def setup_config():
    """
    Write ``config.json`` with a freshly generated ``user_id`` if it does not exist.

    An existing file is left untouched.
    """
    config_path = os.path.join(mem0_dir, "config.json")
    if os.path.exists(config_path):
        return
    payload = {"user_id": str(uuid.uuid4())}
    with open(config_path, "w") as handle:
        json.dump(payload, handle, indent=4)
|
||||
|
||||
|
||||
def get_user_id():
    """
    Return the ``user_id`` stored in the mem0 ``config.json``.

    Returns:
        str: The stored ID, or "anonymous_user" when the file is missing,
        unreadable, not valid JSON, or holds no usable ``user_id``.
    """
    config_path = os.path.join(mem0_dir, "config.json")
    try:
        with open(config_path, "r") as config_file:
            config = json.load(config_file)
    except (OSError, ValueError):
        # Missing/unreadable file or broken JSON (JSONDecodeError is a ValueError).
        # This ID only labels telemetry, so fall back instead of failing.
        return "anonymous_user"
    if not isinstance(config, dict):
        return "anonymous_user"
    return config.get("user_id") or "anonymous_user"
|
||||
71
mem0/memory/storage.py
Normal file
71
mem0/memory/storage.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SQLiteManager:
    """Append-only change log for memories, kept in a SQLite ``history`` table."""

    def __init__(self, db_path=":memory:"):
        """
        Open the database and make sure the ``history`` table exists.

        Args:
            db_path (str): SQLite database path. Defaults to an in-memory database.
        """
        self.connection = sqlite3.connect(db_path, check_same_thread=False)
        self._create_history_table()

    def _create_history_table(self):
        """Create the ``history`` table if it is not there yet."""
        with self.connection:
            self.connection.execute(
                """
                CREATE TABLE IF NOT EXISTS history (
                    id TEXT PRIMARY KEY,
                    memory_id TEXT,
                    prev_value TEXT,
                    new_value TEXT,
                    event TEXT,
                    timestamp DATETIME,
                    is_deleted INTEGER
                )
                """
            )

    @staticmethod
    def _utc_timestamp():
        """Return the current UTC time as a "YYYY-MM-DD HH:MM:SS[.ffffff]" string."""
        from datetime import timezone

        # Same text the implicit sqlite3 datetime adapter produced for a naive
        # UTC datetime, without relying on that deprecated adapter or on
        # ``datetime.utcnow()``.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat(sep=" ")

    def add_history(self, memory_id, prev_value, new_value, event, is_deleted=0):
        """
        Append one change record.

        Args:
            memory_id (str): ID of the memory that changed.
            prev_value (str or None): Value before the change.
            new_value (str or None): Value after the change.
            event (str): Kind of change, e.g. "add", "update" or "delete".
            is_deleted (int): Stored as-is in the ``is_deleted`` column. Defaults to 0.
        """
        with self.connection:
            self.connection.execute(
                """
                INSERT INTO history (id, memory_id, prev_value, new_value, event, timestamp, is_deleted)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    str(uuid.uuid4()),
                    memory_id,
                    prev_value,
                    new_value,
                    event,
                    self._utc_timestamp(),
                    is_deleted,
                ),
            )

    def get_history(self, memory_id):
        """
        Return all change records of a memory, oldest first.

        Args:
            memory_id (str): ID of the memory to look up.

        Returns:
            list: One dict per record with the keys ``id``, ``memory_id``,
            ``prev_value``, ``new_value``, ``event``, ``timestamp`` and ``is_deleted``.
        """
        cursor = self.connection.execute(
            """
            SELECT id, memory_id, prev_value, new_value, event, timestamp, is_deleted
            FROM history
            WHERE memory_id = ?
            ORDER BY timestamp ASC, rowid ASC
            """,
            (memory_id,),
        )
        # ``rowid`` breaks ties so records with equal timestamps keep insertion order.
        columns = (
            "id",
            "memory_id",
            "prev_value",
            "new_value",
            "event",
            "timestamp",
            "is_deleted",
        )
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def reset(self):
        """Erase the whole history and leave an empty, usable table behind."""
        with self.connection:
            self.connection.execute("DROP TABLE IF EXISTS history")
        # Recreate the table; otherwise the next add_history/get_history call
        # on this manager fails with "no such table".
        self._create_history_table()
|
||||
61
mem0/memory/telemetry.py
Normal file
61
mem0/memory/telemetry.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import platform
|
||||
import sys
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from mem0.memory.setup import get_user_id, setup_config
|
||||
|
||||
|
||||
class AnonymousTelemetry:
    """Thin wrapper around a PostHog client that tags events with the local user ID."""

    def __init__(self, project_api_key, host):
        """
        Create the PostHog client and load the user ID.

        Args:
            project_api_key (str): PostHog project API key.
            host (str): PostHog host URL.
        """
        self.posthog = Posthog(project_api_key=project_api_key, host=host)
        # Call setup config to ensure that the user_id is generated
        setup_config()
        self.user_id = get_user_id()

    def capture_event(self, event_name, properties=None):
        """
        Send an event enriched with interpreter and platform details.

        Args:
            event_name (str): Name of the event.
            properties (dict, optional): Extra properties; they override the
                built-in ones on a key clash. Defaults to None.
        """
        extra = {} if properties is None else properties
        environment = {
            "python_version": sys.version,
            "os": sys.platform,
            "os_version": platform.version(),
            "os_release": platform.release(),
            "processor": platform.processor(),
            "machine": platform.machine(),
        }
        payload = {**environment, **extra}
        self.posthog.capture(
            distinct_id=self.user_id, event=event_name, properties=payload
        )

    def identify_user(self, user_id, properties=None):
        """
        Forward an identify call to PostHog.

        Args:
            user_id (str): ID to identify.
            properties (dict, optional): Properties to attach. Defaults to None.
        """
        attached = {} if properties is None else properties
        self.posthog.identify(distinct_id=user_id, properties=attached)

    def close(self):
        """Shut the PostHog client down."""
        self.posthog.shutdown()
|
||||
|
||||
|
||||
# Module-wide telemetry client, created once at import time.
telemetry = AnonymousTelemetry(
    host="https://us.i.posthog.com",
    project_api_key="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX",
)


def capture_event(event_name, memory_instance, additional_data=None):
    """
    Report an event together with a description of the memory instance's setup.

    Args:
        event_name (str): Name of the event.
        memory_instance: Object exposing ``collection_name``, ``embedding_model``
            (with ``dims``), ``vector_store`` and ``llm``.
        additional_data (dict, optional): Extra fields; they override the
            built-in ones on a key clash. Defaults to None.
    """

    def dotted_class_path(obj):
        # "<module>.<ClassName>" of the object's class.
        klass = obj.__class__
        return f"{klass.__module__}.{klass.__name__}"

    event_data = {
        "collection": memory_instance.collection_name,
        "vector_size": memory_instance.embedding_model.dims,
        "history_store": "sqlite",
        "vector_store": dotted_class_path(memory_instance.vector_store),
        "llm": dotted_class_path(memory_instance.llm),
        "embedding_model": dotted_class_path(memory_instance.embedding_model),
        "function": dotted_class_path(memory_instance),
    }
    if additional_data:
        event_data.update(additional_data)

    telemetry.capture_event(event_name, event_data)
|
||||
14
mem0/memory/utils.py
Normal file
14
mem0/memory/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from mem0.configs.prompts import UPDATE_MEMORY_PROMPT
|
||||
|
||||
|
||||
def get_update_memory_prompt(existing_memories, memory, template=UPDATE_MEMORY_PROMPT):
    """
    Fill the update-memory template.

    Args:
        existing_memories: Value substituted for ``{existing_memories}``.
        memory: Value substituted for ``{memory}``.
        template (str): ``str.format`` template. Defaults to ``UPDATE_MEMORY_PROMPT``.

    Returns:
        str: The formatted prompt.
    """
    substitutions = {"existing_memories": existing_memories, "memory": memory}
    return template.format(**substitutions)
|
||||
|
||||
|
||||
def get_update_memory_messages(existing_memories, memory):
    """
    Build the chat transcript for the update-memory request.

    Args:
        existing_memories: Existing memories to show to the model.
        memory: The new memory to merge.

    Returns:
        list: A single user message whose content is the filled-in prompt.
    """
    prompt = get_update_memory_prompt(existing_memories, memory)
    user_message = {"role": "user", "content": prompt}
    return [user_message]
|
||||
0
mem0/vector_stores/__init__.py
Normal file
0
mem0/vector_stores/__init__.py
Normal file
48
mem0/vector_stores/base.py
Normal file
48
mem0/vector_stores/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class VectorStoreBase(ABC):
    """Abstract interface for vector store backends (collections and vectors)."""

    @abstractmethod
    def create_col(self, name, vector_size, distance):
        """Create the collection ``name`` for vectors of ``vector_size``."""
        ...

    @abstractmethod
    def insert(self, name, vectors, payloads=None, ids=None):
        """Add ``vectors`` (with optional payloads and ids) to collection ``name``."""
        ...

    @abstractmethod
    def search(self, name, query, limit=5, filters=None):
        """Find vectors in ``name`` that are similar to ``query``."""
        ...

    @abstractmethod
    def delete(self, name, vector_id):
        """Remove the vector ``vector_id`` from collection ``name``."""
        ...

    @abstractmethod
    def update(self, name, vector_id, vector=None, payload=None):
        """Change the vector and/or payload stored under ``vector_id``."""
        ...

    @abstractmethod
    def get(self, name, vector_id):
        """Fetch the vector stored under ``vector_id``."""
        ...

    @abstractmethod
    def list_cols(self):
        """Enumerate the existing collections."""
        ...

    @abstractmethod
    def delete_col(self, name):
        """Drop the collection ``name``."""
        ...

    @abstractmethod
    def col_info(self, name):
        """Describe the collection ``name``."""
        ...
|
||||
45
mem0/vector_stores/configs.py
Normal file
45
mem0/vector_stores/configs.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
    """
    Connection settings for a Qdrant vector store.

    At least one complete connection option must be supplied: a local ``path``,
    a ``host``/``port`` pair, or a ``url``/``api_key`` pair.
    """

    host: Optional[str] = Field(None, description="Host address for Qdrant server")
    port: Optional[int] = Field(None, description="Port for Qdrant server")
    path: Optional[str] = Field(None, description="Path for local Qdrant database")
    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
    api_key: Optional[str] = Field(None, description="API key for Qdrant server")

    @model_validator(mode="before")
    @classmethod
    def check_host_port_or_path(cls, values):
        """
        Ensure the raw input names at least one complete connection option.

        Args:
            values: Raw, not-yet-validated input handed to the model.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If none of ``path``, ``host`` + ``port`` or
                ``url`` + ``api_key`` is provided.
        """
        # A "before" validator sees the raw input, which is not guaranteed to be
        # a dict (e.g. an already-built QdrantConfig passed as a nested field).
        # Hand anything else back untouched so pydantic applies its own handling
        # instead of this method failing with AttributeError on `.get`.
        if not isinstance(values, dict):
            return values

        has_local_path = bool(values.get("path"))
        has_host_and_port = bool(values.get("host") and values.get("port"))
        has_url_and_key = bool(values.get("url") and values.get("api_key"))

        if not (has_local_path or has_host_and_port or has_url_and_key):
            raise ValueError(
                "Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
            )
        return values
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
    """
    Top-level vector store configuration: a provider name plus its settings.

    Only the "qdrant" provider is accepted by the ``config`` validator.
    """

    provider: str = Field(
        description="Provider of the vector store (e.g., 'qdrant', 'chromadb', 'elasticsearch')",
        default="qdrant",
    )
    config: QdrantConfig = Field(
        description="Configuration for the specific vector store",
        default=QdrantConfig(path="/tmp/qdrant"),
    )

    @field_validator("config")
    @classmethod
    def validate_config(cls, v, values):
        """
        Check that the chosen provider is supported and rebuild its config.

        Args:
            v: The ``config`` value being validated.
            values: Validation info object; ``values.data`` is read for the
                ``provider`` field, which is declared before ``config``.

        Returns:
            QdrantConfig: A fresh config built from ``v.model_dump()``.

        Raises:
            ValueError: If the provider is anything other than "qdrant".
        """
        provider = values.data.get("provider")
        if provider != "qdrant":
            raise ValueError(f"Unsupported vector store provider: {provider}")
        # Re-instantiating runs QdrantConfig's own validation on the dumped fields.
        return QdrantConfig(**v.model_dump())
|
||||
242
mem0/vector_stores/qdrant.py
Normal file
242
mem0/vector_stores/qdrant.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
PointIdsList,
|
||||
PointStruct,
|
||||
Range,
|
||||
VectorParams,
|
||||
)
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
    # Connection settings for Qdrant; every field is optional and defaults to None.
    # NOTE(review): appears unused within this module, and lacks the `url` /
    # `api_key` options that Qdrant.__init__ accepts -- confirm whether it is
    # still needed.
    host: Optional[str] = Field(default=None, description="Host address for Qdrant server")
    port: Optional[int] = Field(default=None, description="Port for Qdrant server")
    path: Optional[str] = Field(default=None, description="Path for local Qdrant database")
|
||||
|
||||
|
||||
class Qdrant(VectorStoreBase):
    """Vector store implementation that delegates every operation to a QdrantClient."""

    def __init__(
        self,
        client=None,
        host="localhost",
        port=6333,
        path=None,
        url=None,
        api_key=None,
    ):
        """
        Initialize the Qdrant vector store.

        When no client is supplied, exactly one connection target is chosen in
        this order of precedence: ``path``, then ``url``, then ``host``/``port``.

        Args:
            client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
            host (str, optional): Host address for Qdrant server. Defaults to "localhost".
            port (int, optional): Port for Qdrant server. Defaults to 6333.
            path (str, optional): Path for local Qdrant database. Defaults to None.
            url (str, optional): Full URL for Qdrant server. Defaults to None.
            api_key (str, optional): API key for Qdrant server. Defaults to None.
        """
        if client:
            self.client = client
            return

        params = self._connection_params(host, port, path, url, api_key)
        self.client = QdrantClient(**params)

    @staticmethod
    def _connection_params(host, port, path, url, api_key):
        """Select the keyword arguments used to construct the QdrantClient."""
        params = {}
        if api_key:
            params["api_key"] = api_key

        # qdrant-client accepts only one of path / url / host. Because `host`
        # and `port` have non-empty defaults, they must not be forwarded when a
        # path or url was requested, otherwise client construction fails.
        if path:
            params["path"] = path
        elif url:
            params["url"] = url
        elif host and port:
            params["host"] = host
            params["port"] = port
        return params

    def create_col(self, name, vector_size, distance=Distance.COSINE):
        """
        Create a new collection, unless one with the same name already exists.

        Args:
            name (str): Name of the collection.
            vector_size (int): Size of the vectors to be stored.
            distance (Distance, optional): Distance metric for vector similarity. Defaults to Distance.COSINE.
        """
        # Skip creating collection if already exists
        existing = self.list_cols().collections
        if any(collection.name == name for collection in existing):
            logging.debug(f"Collection {name} already exists. Skipping creation.")
            return

        self.client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=vector_size, distance=distance),
        )

    def insert(self, name, vectors, payloads=None, ids=None):
        """
        Insert vectors into a collection.

        Args:
            name (str): Name of the collection.
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors.
                When omitted or empty, every point gets an empty payload. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When None,
                each vector's position in ``vectors`` is used as its ID. Defaults to None.
        """
        points = []
        for idx, vector in enumerate(vectors):
            point_id = idx if ids is None else ids[idx]
            payload = payloads[idx] if payloads else {}
            points.append(PointStruct(id=point_id, vector=vector, payload=payload))
        self.client.upsert(collection_name=name, points=points)

    def _create_filter(self, filters):
        """
        Create a Filter object from the provided filters.

        A dict value containing both "gte" and "lte" becomes a range condition;
        every other value becomes an exact-match condition.

        Args:
            filters (dict): Mapping of payload key to the value (or range dict) to match.

        Returns:
            Filter: Filter requiring all conditions, or None when ``filters`` is empty.
        """
        conditions = []
        for key, value in filters.items():
            is_range = isinstance(value, dict) and "gte" in value and "lte" in value
            if is_range:
                condition = FieldCondition(
                    key=key, range=Range(gte=value["gte"], lte=value["lte"])
                )
            else:
                condition = FieldCondition(key=key, match=MatchValue(value=value))
            conditions.append(condition)
        return Filter(must=conditions) if conditions else None

    def search(self, name, query, limit=5, filters=None):
        """
        Search for similar vectors.

        Args:
            name (str): Name of the collection.
            query (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            The hits returned by the client's ``search`` call.
        """
        query_filter = self._create_filter(filters) if filters else None
        return self.client.search(
            collection_name=name,
            query_vector=query,
            query_filter=query_filter,
            limit=limit,
        )

    def delete(self, name, vector_id):
        """
        Delete a vector by ID.

        Args:
            name (str): Name of the collection.
            vector_id: ID of the vector to delete.
        """
        selector = PointIdsList(points=[vector_id])
        self.client.delete(collection_name=name, points_selector=selector)

    def update(self, name, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        With a vector, the point is upserted with the given vector and payload.
        Without a vector, only the payload is replaced; if neither is given the
        call does nothing.

        Args:
            name (str): Name of the collection.
            vector_id: ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        if vector is None:
            # A PointStruct cannot be built without a vector, so a payload-only
            # update goes through overwrite_payload, which replaces the whole
            # payload just as an upsert would.
            if payload is not None:
                self.client.overwrite_payload(
                    collection_name=name, payload=payload, points=[vector_id]
                )
            return

        point = PointStruct(id=vector_id, vector=vector, payload=payload)
        self.client.upsert(collection_name=name, points=[point])

    def get(self, name, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            name (str): Name of the collection.
            vector_id: ID of the vector to retrieve.

        Returns:
            The first record returned by the client's ``retrieve`` call (payload
            included), or None when nothing was found.
        """
        records = self.client.retrieve(
            collection_name=name, ids=[vector_id], with_payload=True
        )
        return records[0] if records else None

    def list_cols(self):
        """
        List all collections.

        Returns:
            The client's ``get_collections`` response; its ``collections``
            attribute holds entries that each expose a ``name``.
        """
        return self.client.get_collections()

    def delete_col(self, name):
        """
        Delete a collection.

        Args:
            name (str): Name of the collection to delete.
        """
        self.client.delete_collection(collection_name=name)

    def col_info(self, name):
        """
        Get information about a collection.

        Args:
            name (str): Name of the collection.

        Returns:
            The client's ``get_collection`` response for the collection.
        """
        return self.client.get_collection(collection_name=name)

    def list(self, name, filters=None, limit=100):
        """
        List vectors in a collection, without their vector data.

        Args:
            name (str): Name of the collection.
            filters (dict, optional): Filters restricting which vectors are listed. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            The raw result of the client's ``scroll`` call (in qdrant-client this
            is a ``(records, next_page_offset)`` tuple).
        """
        scroll_filter = self._create_filter(filters) if filters else None
        return self.client.scroll(
            collection_name=name,
            scroll_filter=scroll_filter,
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
|
||||
Reference in New Issue
Block a user