Rename embedchain to mem0 and open-source the code for long-term memory (#1474)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Taranjeet Singh
2024-07-12 07:51:33 -07:00
committed by GitHub
parent 83e8c97295
commit f842a92e25
665 changed files with 9427 additions and 6592 deletions

5
mem0/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
import importlib.metadata

try:
    # Read the version from the installed distribution's metadata.
    __version__ = importlib.metadata.version(__package__ or __name__)
except importlib.metadata.PackageNotFoundError:
    # No installed distribution matches this package name (for example when
    # running from a source checkout); fall back instead of failing on import.
    __version__ = "0.0.0"

from mem0.memory.main import Memory  # noqa

0
mem0/configs/__init__.py Normal file
View File

0
mem0/configs/base.py Normal file
View File

17
mem0/configs/prompts.py Normal file
View File

@@ -0,0 +1,17 @@
# Prompt sent to the LLM when deciding how new information changes the stored
# memories. It is a `str.format` template with two placeholders:
# `{existing_memories}` and `{memory}`.
UPDATE_MEMORY_PROMPT = """
You are an expert at merging, updating, and organizing user memories. When provided with existing memories and new information, your task is to merge and update the memory list to reflect the most accurate and current information. You are also provided with the matching score for each existing memory to the new information. Make sure to leverage this information to make informed decisions about which memories to update or merge.
Guidelines:
- Eliminate duplicate memories and merge related memories to ensure a concise and updated list.
- If a memory is directly contradicted by new information, critically evaluate both pieces of information:
- If the new memory provides a more recent or accurate update, replace the old memory with new one.
- If the new memory seems inaccurate or less detailed, retain the original and discard the new one.
- Maintain a consistent and clear style throughout all memories, ensuring each entry is concise yet informative.
- If the new memory is a variation or extension of an existing memory, update the existing memory to reflect the new information.
Here are the details of the task:
- Existing Memories:
{existing_memories}
- New Memory: {memory}
"""

View File

16
mem0/embeddings/base.py Normal file
View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
class EmbeddingBase(ABC):
    """Interface implemented by every embedding backend."""

    @abstractmethod
    def embed(self, text):
        """
        Turn a piece of text into its embedding vector.

        Args:
            text (str): Input text to embed.

        Returns:
            list: Embedding vector for ``text``.
        """

View File

@@ -0,0 +1,19 @@
from sentence_transformers import SentenceTransformer

from mem0.embeddings.base import EmbeddingBase


class HuggingFaceEmbedding(EmbeddingBase):
    """Embedding backend that runs a sentence-transformers model."""

    def __init__(self, model_name="multi-qa-MiniLM-L6-cos-v1"):
        """
        Load the sentence-transformers model.

        Args:
            model_name (str, optional): Name of the model to load.
                Defaults to "multi-qa-MiniLM-L6-cos-v1".
        """
        self.model = SentenceTransformer(model_name)
        # Expose the vector size like the other embedding backends do, so the
        # vector store collection can be created with the right dimension.
        self.dims = self.model.get_sentence_embedding_dimension()

    def embed(self, text):
        """
        Get the embedding for the given text using Hugging Face.

        Implements the abstract ``EmbeddingBase.embed``; without it the class
        cannot be instantiated.

        Args:
            text (str): The text to embed.

        Returns:
            The embedding vector, exactly as returned by
            ``SentenceTransformer.encode``.
        """
        return self.model.encode(text)

    def get_embedding(self, text):
        """Backward-compatible alias for :meth:`embed`."""
        return self.embed(text)

30
mem0/embeddings/ollama.py Normal file
View File

@@ -0,0 +1,30 @@
import ollama

from mem0.embeddings.base import EmbeddingBase


class OllamaEmbedding(EmbeddingBase):
    """Embedding backend that calls a model through the ``ollama`` client."""

    def __init__(self, model="nomic-embed-text", dims=512):
        """
        Args:
            model (str, optional): Ollama model name. Defaults to "nomic-embed-text".
            dims (int, optional): Size of the vectors produced by ``model``.
                Defaults to 512, the value previously hard-coded here.
                NOTE(review): confirm this matches the chosen model's real
                output size before creating a collection with it.
        """
        self.model = model
        self.dims = dims
        self._ensure_model_exists()

    def _ensure_model_exists(self):
        """
        Ensure the specified model exists locally. If not, pull it from Ollama.
        """
        model_list = [m["name"] for m in ollama.list()["models"]]
        # NOTE(review): this is a prefix match, so a differently named model
        # that merely starts with the same characters also counts as present.
        if not any(m.startswith(self.model) for m in model_list):
            ollama.pull(self.model)

    def embed(self, text):
        """
        Get the embedding for the given text using Ollama.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector.
        """
        response = ollama.embeddings(model=self.model, prompt=text)
        return response["embedding"]

27
mem0/embeddings/openai.py Normal file
View File

@@ -0,0 +1,27 @@
from openai import OpenAI
from mem0.embeddings.base import EmbeddingBase
class OpenAIEmbedding(EmbeddingBase):
    """Embedding backend that calls the OpenAI embeddings endpoint."""

    def __init__(self, model="text-embedding-3-small"):
        """
        Args:
            model (str, optional): Embedding model name.
                Defaults to "text-embedding-3-small".
        """
        self.model = model
        self.dims = 1536
        self.client = OpenAI()

    def embed(self, text):
        """
        Embed ``text`` with the configured OpenAI model.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector.
        """
        # Newlines are replaced by spaces before the text is sent.
        single_line = text.replace("\n", " ")
        result = self.client.embeddings.create(input=[single_line], model=self.model)
        first_item = result.data[0]
        return first_item.embedding

0
mem0/llms/__init__.py Normal file
View File

16
mem0/llms/base.py Normal file
View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
class LLMBase(ABC):
    """Interface implemented by every LLM backend."""

    @abstractmethod
    def generate_response(self, messages):
        """
        Produce a model response for a conversation.

        Args:
            messages (list): Message dicts, each with 'role' and 'content'.

        Returns:
            str: The generated response.
        """

29
mem0/llms/ollama.py Normal file
View File

@@ -0,0 +1,29 @@
import ollama

from mem0.llms.base import LLMBase


class OllamaLLM(LLMBase):
    """LLM backend that chats with a model through the ``ollama`` client."""

    def __init__(self, model="llama3"):
        """
        Args:
            model (str, optional): Ollama model name. Defaults to "llama3".
        """
        self.model = model
        self._ensure_model_exists()

    def _ensure_model_exists(self):
        """
        Ensure the specified model exists locally. If not, pull it from Ollama.
        """
        model_list = [m["name"] for m in ollama.list()["models"]]
        # NOTE(review): this is a prefix match, so a differently named model
        # that merely starts with the same characters also counts as present.
        if not any(m.startswith(self.model) for m in model_list):
            ollama.pull(self.model)

    def generate_response(self, messages):
        """
        Generate a response based on the given messages using Ollama.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.

        Returns:
            str: The generated response.
        """
        response = ollama.chat(model=self.model, messages=messages)
        return response["message"]["content"]

41
mem0/llms/openai.py Normal file
View File

@@ -0,0 +1,41 @@
from typing import Dict, List, Optional
from openai import OpenAI
from mem0.llms.base import LLMBase
class OpenAILLM(LLMBase):
    """LLM backend that calls the OpenAI chat completions endpoint."""

    def __init__(self, model="gpt-4o"):
        """
        Args:
            model (str, optional): Chat model name. Defaults to "gpt-4o".
        """
        self.model = model
        self.client = OpenAI()

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Send ``messages`` to the OpenAI chat completions endpoint.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Forwarded to the API only
                when truthy. Defaults to None.
            tools (list, optional): Tools the model can call; forwarded together
                with ``tool_choice`` only when truthy. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            The completion object returned by the client, unmodified. Callers
            read the message (and any tool calls) from ``response.choices``.
        """
        request = {"model": self.model, "messages": messages}
        if response_format:
            request["response_format"] = response_format
        if tools:
            request.update(tools=tools, tool_choice=tool_choice)
        return self.client.chat.completions.create(**request)

View File

View File

54
mem0/llms/utils/tools.py Normal file
View File

@@ -0,0 +1,54 @@
def _function_tool(name, description, properties, required):
    """Build one function-calling tool schema with an object-typed parameter block."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool the model calls to store a brand-new memory.
ADD_MEMORY_TOOL = _function_tool(
    "add_memory",
    "Add a memory",
    {"data": {"type": "string", "description": "Data to add to memory"}},
    ["data"],
)

# Tool the model calls to rewrite an existing memory, identified by memory_id.
UPDATE_MEMORY_TOOL = _function_tool(
    "update_memory",
    "Update memory provided ID and data",
    {
        "memory_id": {
            "type": "string",
            "description": "memory_id of the memory to update",
        },
        "data": {
            "type": "string",
            "description": "Updated data for the memory",
        },
    },
    ["memory_id", "data"],
)

# Tool the model calls to remove an existing memory, identified by memory_id.
DELETE_MEMORY_TOOL = _function_tool(
    "delete_memory",
    "Delete memory by memory_id",
    {
        "memory_id": {
            "type": "string",
            "description": "memory_id of the memory to delete",
        }
    },
    ["memory_id"],
)

0
mem0/memory/__init__.py Normal file
View File

63
mem0/memory/base.py Normal file
View File

@@ -0,0 +1,63 @@
from abc import ABC, abstractmethod
class MemoryBase(ABC):
    """Interface for memory stores: read, list, update, delete and audit memories."""

    @abstractmethod
    def get(self, memory_id):
        """
        Fetch a single memory.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory.
        """

    @abstractmethod
    def get_all(self):
        """
        Enumerate the stored memories.

        Returns:
            list: List of all memories.
        """

    @abstractmethod
    def update(self, memory_id, data):
        """
        Change the content of an existing memory.

        Args:
            memory_id (str): ID of the memory to update.
            data (dict): Data to update the memory with.

        Returns:
            dict: Updated memory.
        """

    @abstractmethod
    def delete(self, memory_id):
        """
        Remove a single memory.

        Args:
            memory_id (str): ID of the memory to delete.
        """

    @abstractmethod
    def history(self, memory_id):
        """
        Return the recorded changes of one memory.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
        """

401
mem0/memory/main.py Normal file
View File

@@ -0,0 +1,401 @@
import json
import logging
import os
import time
import uuid
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, ValidationError
from mem0.embeddings.openai import OpenAIEmbedding
from mem0.llms.openai import OpenAILLM
from mem0.llms.utils.tools import (
ADD_MEMORY_TOOL,
DELETE_MEMORY_TOOL,
UPDATE_MEMORY_TOOL,
)
from mem0.memory.base import MemoryBase
from mem0.memory.setup import mem0_dir, setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_update_memory_messages
from mem0.vector_stores.configs import VectorStoreConfig
from mem0.vector_stores.qdrant import Qdrant
# Run at import time: make sure the user-level config file (which holds the
# generated user_id) exists before any Memory instance is created.
setup_config()
class MemoryItem(BaseModel):
    """A single stored memory in the shape returned to callers."""

    id: str = Field(description="The unique identifier for the text data")
    text: str = Field(description="The text content")
    # Metadata values are arbitrary objects, not only strings.
    metadata: Dict[str, Any] = Field(
        description="Additional metadata for the text data", default_factory=dict
    )
    score: Optional[float] = Field(
        default=None, description="The score associated with the text data"
    )
class MemoryConfig(BaseModel):
    """Settings for a memory store: vector store, history database and collection."""

    vector_store: VectorStoreConfig = Field(
        default_factory=VectorStoreConfig,
        description="Configuration for the vector store",
    )
    history_db_path: str = Field(
        default=os.path.join(mem0_dir, "history.db"),
        description="Path to the history database",
    )
    collection_name: str = Field(description="Name of the collection", default="mem0")
    embedding_model_dims: int = Field(
        description="Dimensions of the embedding model", default=1536
    )
class Memory(MemoryBase):
    """
    Memory store that embeds text, keeps the vectors in a vector store, asks an
    LLM (through tool calls) whether new data adds, updates or deletes memories,
    and records every change in a SQLite history table.
    """

    def __init__(self, config: Optional[MemoryConfig] = None):
        """
        Args:
            config (MemoryConfig, optional): Store configuration. A fresh
                default ``MemoryConfig`` is built when omitted.

        Raises:
            ValueError: If the configured vector store provider is unsupported.
        """
        # Build the default per instance: a `MemoryConfig()` default argument
        # is evaluated once at definition time and shared by every instance.
        self.config = config if config is not None else MemoryConfig()
        self.embedding_model = OpenAIEmbedding()

        # Initialize the appropriate vector store based on the configuration
        provider = self.config.vector_store.provider
        vector_store_config = self.config.vector_store.config
        if provider == "qdrant":
            self.vector_store = Qdrant(
                host=vector_store_config.host,
                port=vector_store_config.port,
                path=vector_store_config.path,
                url=vector_store_config.url,
                api_key=vector_store_config.api_key,
            )
        else:
            raise ValueError(f"Unsupported vector store type: {provider}")

        self.llm = OpenAILLM()
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.collection_name
        self.vector_store.create_col(
            name=self.collection_name, vector_size=self.embedding_model.dims
        )
        capture_event("mem0.init", self)

    @classmethod
    def from_config(cls, config_dict: Dict[str, Any]):
        """
        Build a Memory from a plain dict of ``MemoryConfig`` fields.

        Raises:
            ValidationError: If ``config_dict`` does not validate (logged first).
        """
        try:
            config = MemoryConfig(**config_dict)
        except ValidationError as e:
            logging.error(f"Configuration validation error: {e}")
            raise
        return cls(config)

    @staticmethod
    def _build_filters(user_id=None, agent_id=None, run_id=None, filters=None):
        """Return a new filter dict: a copy of ``filters`` plus every truthy scope ID."""
        scoped = dict(filters or {})
        for key, value in (
            ("user_id", user_id),
            ("agent_id", agent_id),
            ("run_id", run_id),
        ):
            if value:
                scoped[key] = value
        return scoped

    def add(
        self,
        data,
        user_id=None,
        agent_id=None,
        run_id=None,
        metadata=None,
        filters=None,
    ):
        """
        Create a new memory.

        The data is embedded, the closest existing memories (within the given
        filters) are retrieved, and the LLM decides via tool calls whether to
        add, update or delete memories.

        Args:
            data (str): Data to store in the memory.
            user_id (str, optional): ID of the user creating the memory. Defaults to None.
            agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
            run_id (str, optional): ID of the run creating the memory. Defaults to None.
            metadata (dict, optional): Metadata to store with the memory. Defaults to None.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: One dict per executed tool call, with keys ``id`` (memory ID),
            ``event`` ("add", "update" or "delete") and ``data``. Empty when the
            model made no tool call.
        """
        if metadata is None:
            logging.warning("Metadata not provided. Using empty metadata.")
        embeddings = self.embedding_model.embed(data)

        # Work on copies so the caller's dicts are never mutated.
        scope = self._build_filters(user_id, agent_id, run_id)
        filters = {**(filters or {}), **scope}
        metadata = {**(metadata or {}), **scope}

        existing_memories = self.vector_store.search(
            name=self.collection_name,
            query=embeddings,
            limit=5,
            filters=filters,
        )
        existing_memories = [
            MemoryItem(
                id=mem.id,
                score=mem.score,
                metadata=mem.payload,
                text=mem.payload["data"],
            )
            for mem in existing_memories
        ]
        serialized_existing_memories = [
            item.model_dump(include={"id", "text", "score"})
            for item in existing_memories
        ]
        logging.info(f"Total existing memories: {len(existing_memories)}")
        messages = get_update_memory_messages(serialized_existing_memories, data)
        # Add tools for noop, add, update, delete memory.
        tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
        response = self.llm.generate_response(messages=messages, tools=tools)
        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls

        response = []
        if tool_calls:
            available_functions = {
                "add_memory": self._create_memory_tool,
                "update_memory": self._update_memory_tool,
                "delete_memory": self._delete_memory_tool,
            }
            for tool_call in tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                if function_to_call is None:
                    # The model named a tool that was never offered; skip it
                    # instead of aborting the remaining tool calls.
                    logging.warning(f"Ignoring unknown tool call: {function_name}")
                    continue
                function_args = json.loads(tool_call.function.arguments)
                logging.info(
                    f"[openai_func] func: {function_name}, args: {function_args}"
                )
                # Pass metadata to the function if it requires it. Each call
                # gets its own copy because the tools add keys to it.
                if function_name in ["add_memory", "update_memory"]:
                    function_args["metadata"] = dict(metadata)
                function_result = function_to_call(**function_args)
                # Fetch the memory_id from the response
                response.append(
                    {
                        "id": function_result,
                        "event": function_name.replace("_memory", ""),
                        "data": function_args.get("data"),
                    }
                )
                capture_event(
                    "mem0.add.function_call",
                    self,
                    {"memory_id": function_result, "function_name": function_name},
                )
        capture_event("mem0.add", self)
        return response

    def get(self, memory_id):
        """
        Retrieve a memory by ID.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory (``id``, ``text``, ``metadata``), or None
            when no memory has that ID.
        """
        capture_event("mem0.get", self, {"memory_id": memory_id})
        memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id)
        if not memory:
            return None
        return MemoryItem(
            id=memory.id,
            metadata=memory.payload,
            text=memory.payload["data"],
        ).model_dump(exclude={"score"})

    def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
        List all memories.

        Args:
            user_id (str, optional): Only list memories of this user. Defaults to None.
            agent_id (str, optional): Only list memories of this agent. Defaults to None.
            run_id (str, optional): Only list memories of this run. Defaults to None.
            limit (int, optional): Maximum number of memories. Defaults to 100.

        Returns:
            list: List of all memories.
        """
        filters = self._build_filters(user_id, agent_id, run_id)
        capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
        memories = self.vector_store.list(
            name=self.collection_name, filters=filters, limit=limit
        )
        # The vector store's list result is indexed: element 0 holds the records.
        return [
            MemoryItem(
                id=mem.id,
                metadata=mem.payload,
                text=mem.payload["data"],
            ).model_dump(exclude={"score"})
            for mem in memories[0]
        ]

    def search(
        self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
    ):
        """
        Search for memories.

        Args:
            query (str): Query to search for.
            user_id (str, optional): ID of the user to search for. Defaults to None.
            agent_id (str, optional): ID of the agent to search for. Defaults to None.
            run_id (str, optional): ID of the run to search for. Defaults to None.
            limit (int, optional): Limit the number of results. Defaults to 100.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: List of search results, each including its ``score``.
        """
        # Copy so the caller's filters dict is not mutated.
        filters = self._build_filters(user_id, agent_id, run_id, filters)
        capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
        embeddings = self.embedding_model.embed(query)
        memories = self.vector_store.search(
            name=self.collection_name, query=embeddings, limit=limit, filters=filters
        )
        return [
            MemoryItem(
                id=mem.id,
                metadata=mem.payload,
                score=mem.score,
                text=mem.payload["data"],
            ).model_dump()
            for mem in memories
        ]

    def update(self, memory_id, data):
        """
        Update a memory by ID.

        Args:
            memory_id (str): ID of the memory to update.
            data (str): New content for the memory.

        Returns:
            None

        Raises:
            ValueError: If no memory has that ID.
        """
        capture_event("mem0.update", self, {"memory_id": memory_id})
        self._update_memory_tool(memory_id, data)

    def delete(self, memory_id):
        """
        Delete a memory by ID.

        Args:
            memory_id (str): ID of the memory to delete.

        Raises:
            ValueError: If no memory has that ID.
        """
        capture_event("mem0.delete", self, {"memory_id": memory_id})
        self._delete_memory_tool(memory_id)

    def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
        Delete all memories.

        Args:
            user_id (str, optional): ID of the user to delete memories for. Defaults to None.
            agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
            run_id (str, optional): ID of the run to delete memories for. Defaults to None.

        Raises:
            ValueError: If no filter is given.
        """
        filters = self._build_filters(user_id, agent_id, run_id)
        if not filters:
            raise ValueError(
                "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
            )
        capture_event("mem0.delete_all", self, {"filters": len(filters)})

        # The vector store returns one page at a time, so keep fetching until a
        # page brings nothing that has not already been deleted.
        deleted_ids = set()
        while True:
            page = self.vector_store.list(name=self.collection_name, filters=filters)[0]
            pending = [memory for memory in page if memory.id not in deleted_ids]
            if not pending:
                break
            for memory in pending:
                self._delete_memory_tool(memory.id)
                deleted_ids.add(memory.id)

    def history(self, memory_id):
        """
        Get the history of changes for a memory by ID.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
        """
        capture_event("mem0.history", self, {"memory_id": memory_id})
        return self.db.get_history(memory_id)

    def _create_memory_tool(self, data, metadata=None):
        """Embed ``data``, insert it as a new memory and return the new memory ID."""
        logging.info(f"Creating memory with {data=}")
        embeddings = self.embedding_model.embed(data)
        memory_id = str(uuid.uuid4())
        # Copy so the caller's metadata dict does not gain data/created_at keys.
        metadata = dict(metadata or {})
        metadata["data"] = data
        metadata["created_at"] = int(time.time())

        self.vector_store.insert(
            name=self.collection_name,
            vectors=[embeddings],
            ids=[memory_id],
            payloads=[metadata],
        )
        self.db.add_history(memory_id, None, data, "add")
        return memory_id

    def _update_memory_tool(self, memory_id, data, metadata=None):
        """
        Re-embed ``data`` and overwrite the memory, returning its ID.

        Raises:
            ValueError: If no memory has that ID.
        """
        existing_memory = self.vector_store.get(
            name=self.collection_name, vector_id=memory_id
        )
        if existing_memory is None:
            raise ValueError(f"Memory with ID {memory_id} not found")
        existing_payload = existing_memory.payload or {}
        prev_value = existing_payload.get("data")

        # Start from the stored payload: the vector store update replaces the
        # whole payload, so fields such as user_id or created_at would
        # otherwise be lost on every update.
        new_metadata = {**existing_payload, **(metadata or {})}
        new_metadata["data"] = data
        new_metadata["updated_at"] = int(time.time())
        embeddings = self.embedding_model.embed(data)
        self.vector_store.update(
            name=self.collection_name,
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )
        logging.info(f"Updating memory with ID {memory_id=} with {data=}")
        self.db.add_history(memory_id, prev_value, data, "update")
        return memory_id

    def _delete_memory_tool(self, memory_id):
        """
        Delete the memory and record the deletion, returning its ID.

        Raises:
            ValueError: If no memory has that ID.
        """
        logging.info(f"Deleting memory with {memory_id=}")
        existing_memory = self.vector_store.get(
            name=self.collection_name, vector_id=memory_id
        )
        if existing_memory is None:
            raise ValueError(f"Memory with ID {memory_id} not found")
        prev_value = (existing_memory.payload or {}).get("data")
        self.vector_store.delete(name=self.collection_name, vector_id=memory_id)
        self.db.add_history(memory_id, prev_value, None, "delete", is_deleted=1)
        return memory_id

    def reset(self):
        """
        Reset the memory store.

        Drops the collection and the history, then recreates both so this
        instance stays usable afterwards.
        """
        self.vector_store.delete_col(name=self.collection_name)
        self.db.reset()
        # Re-open the history store (its constructor creates the table when
        # missing) and recreate the empty collection.
        self.db = SQLiteManager(self.config.history_db_path)
        self.vector_store.create_col(
            name=self.collection_name, vector_size=self.embedding_model.dims
        )
        capture_event("mem0.reset", self)

    def chat(self, query):
        """Not implemented yet; always raises NotImplementedError."""
        raise NotImplementedError("Chat function not implemented yet.")

28
mem0/memory/setup.py Normal file
View File

@@ -0,0 +1,28 @@
import json
import os
import uuid
# Set up the directory path
# mem0 keeps its per-user files in ~/.mem0; the directory is created as a
# side effect of importing this module (a no-op when it already exists).
home_dir = os.path.expanduser("~")
mem0_dir = os.path.join(home_dir, ".mem0")
os.makedirs(mem0_dir, exist_ok=True)
def setup_config():
    """Create ``config.json`` in the mem0 directory with a new random ``user_id``.

    Does nothing when the file already exists.
    """
    config_path = os.path.join(mem0_dir, "config.json")
    if os.path.exists(config_path):
        return
    fresh_config = {"user_id": str(uuid.uuid4())}
    with open(config_path, "w") as handle:
        json.dump(fresh_config, handle, indent=4)
def get_user_id():
    """
    Return the ``user_id`` stored in the mem0 config file.

    Returns:
        str: The stored ID, or "anonymous_user" when the config file is
        missing, unreadable, not valid JSON, not a JSON object, or has no
        usable ``user_id`` entry.
    """
    config_path = os.path.join(mem0_dir, "config.json")
    try:
        with open(config_path, "r") as config_file:
            config = json.load(config_file)
    except (OSError, ValueError):
        # Missing/unreadable file, or malformed JSON (JSONDecodeError is a
        # ValueError): identify the user anonymously instead of failing.
        return "anonymous_user"
    if not isinstance(config, dict):
        return "anonymous_user"
    return config.get("user_id") or "anonymous_user"

71
mem0/memory/storage.py Normal file
View File

@@ -0,0 +1,71 @@
import sqlite3
import uuid
from datetime import datetime
class SQLiteManager:
    """Change log for memories, stored in a SQLite ``history`` table."""

    def __init__(self, db_path=":memory:"):
        """
        Args:
            db_path (str, optional): SQLite database path. Defaults to
                ":memory:" (an in-memory database).
        """
        self.connection = sqlite3.connect(db_path, check_same_thread=False)
        self._create_history_table()

    def _create_history_table(self):
        """Create the ``history`` table if it does not exist yet."""
        with self.connection:
            self.connection.execute(
                """
                CREATE TABLE IF NOT EXISTS history (
                    id TEXT PRIMARY KEY,
                    memory_id TEXT,
                    prev_value TEXT,
                    new_value TEXT,
                    event TEXT,
                    timestamp DATETIME,
                    is_deleted INTEGER
                )
                """
            )

    def add_history(self, memory_id, prev_value, new_value, event, is_deleted=0):
        """
        Append one change record for a memory.

        Args:
            memory_id (str): ID of the memory that changed.
            prev_value (str or None): Content before the change.
            new_value (str or None): Content after the change.
            event (str): Kind of change, e.g. "add", "update" or "delete".
            is_deleted (int, optional): 1 when the memory was deleted. Defaults to 0.
        """
        from datetime import timezone

        # Store the UTC time as explicit text. This is the same
        # "YYYY-MM-DD HH:MM:SS[.ffffff]" form the implicit sqlite3 datetime
        # adapter produced, without relying on that adapter or on
        # datetime.utcnow(), both deprecated since Python 3.12.
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat(" ")
        with self.connection:
            self.connection.execute(
                """
                INSERT INTO history (id, memory_id, prev_value, new_value, event, timestamp, is_deleted)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    str(uuid.uuid4()),
                    memory_id,
                    prev_value,
                    new_value,
                    event,
                    timestamp,
                    is_deleted,
                ),
            )

    def get_history(self, memory_id):
        """
        Return all change records of a memory, oldest first.

        Args:
            memory_id (str): ID of the memory.

        Returns:
            list: One dict per record with keys ``id``, ``memory_id``,
            ``prev_value``, ``new_value``, ``event``, ``timestamp`` and
            ``is_deleted``.
        """
        # rowid breaks ties between records sharing a timestamp, so the
        # insertion order is preserved deterministically.
        cursor = self.connection.execute(
            """
            SELECT id, memory_id, prev_value, new_value, event, timestamp, is_deleted
            FROM history
            WHERE memory_id = ?
            ORDER BY timestamp ASC, rowid ASC
            """,
            (memory_id,),
        )
        columns = (
            "id",
            "memory_id",
            "prev_value",
            "new_value",
            "event",
            "timestamp",
            "is_deleted",
        )
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def reset(self):
        """Erase all history and leave an empty, usable ``history`` table."""
        with self.connection:
            self.connection.execute("DROP TABLE IF EXISTS history")
        # Recreate the table so add_history/get_history keep working.
        self._create_history_table()

61
mem0/memory/telemetry.py Normal file
View File

@@ -0,0 +1,61 @@
import platform
import sys
from posthog import Posthog
from mem0.memory.setup import get_user_id, setup_config
class AnonymousTelemetry:
    """Thin wrapper around a PostHog client that tags events with the stored user ID."""

    def __init__(self, project_api_key, host):
        """
        Args:
            project_api_key (str): PostHog project API key.
            host (str): PostHog host URL.
        """
        self.posthog = Posthog(project_api_key=project_api_key, host=host)
        # Call setup config to ensure that the user_id is generated
        setup_config()
        self.user_id = get_user_id()

    def capture_event(self, event_name, properties=None):
        """
        Send an event, adding interpreter and platform details to its properties.

        Args:
            event_name (str): Name of the event.
            properties (dict, optional): Extra properties; on a key clash they
                override the environment details. Defaults to None.
        """
        payload = {
            "python_version": sys.version,
            "os": sys.platform,
            "os_version": platform.version(),
            "os_release": platform.release(),
            "processor": platform.processor(),
            "machine": platform.machine(),
        }
        if properties is not None:
            payload.update(properties)
        self.posthog.capture(
            distinct_id=self.user_id, event=event_name, properties=payload
        )

    def identify_user(self, user_id, properties=None):
        """
        Attach properties to a user ID.

        Args:
            user_id (str): ID to identify.
            properties (dict, optional): Properties to attach. Defaults to None.
        """
        user_properties = {} if properties is None else properties
        self.posthog.identify(distinct_id=user_id, properties=user_properties)

    def close(self):
        """Shut down the underlying PostHog client."""
        self.posthog.shutdown()
# Initialize AnonymousTelemetry
# Module-level client used by capture_event(); it is constructed when this
# module is imported.
telemetry = AnonymousTelemetry(
    project_api_key="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX",
    host="https://us.i.posthog.com",
)
def capture_event(event_name, memory_instance, additional_data=None):
    """
    Report an event describing how a memory instance is set up.

    Args:
        event_name (str): Name of the event.
        memory_instance: Object exposing ``collection_name``, ``embedding_model``
            (with ``dims``), ``vector_store`` and ``llm``.
        additional_data (dict, optional): Extra fields merged over the defaults.
            Defaults to None.
    """

    def _qualified_name(obj):
        # "<module>.<class name>" of the object's class.
        cls = obj.__class__
        return f"{cls.__module__}.{cls.__name__}"

    event_data = {
        "collection": memory_instance.collection_name,
        "vector_size": memory_instance.embedding_model.dims,
        "history_store": "sqlite",
        "vector_store": _qualified_name(memory_instance.vector_store),
        "llm": _qualified_name(memory_instance.llm),
        "embedding_model": _qualified_name(memory_instance.embedding_model),
        "function": _qualified_name(memory_instance),
    }
    if additional_data:
        event_data.update(additional_data)
    telemetry.capture_event(event_name, event_data)

14
mem0/memory/utils.py Normal file
View File

@@ -0,0 +1,14 @@
from mem0.configs.prompts import UPDATE_MEMORY_PROMPT
def get_update_memory_prompt(existing_memories, memory, template=UPDATE_MEMORY_PROMPT):
    """
    Fill the update-memory prompt template.

    Args:
        existing_memories: Value substituted for ``{existing_memories}``.
        memory: Value substituted for ``{memory}``.
        template (str, optional): ``str.format`` template containing both
            placeholders. Defaults to UPDATE_MEMORY_PROMPT.

    Returns:
        str: The rendered prompt.
    """
    substitutions = {"existing_memories": existing_memories, "memory": memory}
    return template.format(**substitutions)
def get_update_memory_messages(existing_memories, memory):
    """
    Build the chat message list for the update-memory request.

    Args:
        existing_memories: Existing memories to show to the model.
        memory: The new information.

    Returns:
        list: A single user message whose content is the rendered prompt.
    """
    prompt = get_update_memory_prompt(existing_memories, memory)
    user_message = {"role": "user", "content": prompt}
    return [user_message]

View File

View File

@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
class VectorStoreBase(ABC):
    """Interface implemented by every vector store backend."""

    @abstractmethod
    def create_col(self, name, vector_size, distance):
        """Create a new collection."""
        pass

    @abstractmethod
    def insert(self, name, vectors, payloads=None, ids=None):
        """Insert vectors into a collection."""
        pass

    @abstractmethod
    def search(self, name, query, limit=5, filters=None):
        """Search for similar vectors."""
        pass

    @abstractmethod
    def delete(self, name, vector_id):
        """Delete a vector by ID."""
        pass

    @abstractmethod
    def update(self, name, vector_id, vector=None, payload=None):
        """Update a vector and its payload."""
        pass

    @abstractmethod
    def get(self, name, vector_id):
        """Retrieve a vector by ID."""
        pass

    @abstractmethod
    def list_cols(self):
        """List all collections."""
        pass

    @abstractmethod
    def delete_col(self, name):
        """Delete a collection."""
        pass

    @abstractmethod
    def col_info(self, name):
        """Get information about a collection."""
        pass

    def list(self, name, filters=None, limit=100):
        """
        List the vectors in a collection.

        The memory layer calls this on every vector store, so it is part of
        the interface. It is not marked abstract so that existing subclasses
        without it can still be instantiated; they fail only when it is called.

        Raises:
            NotImplementedError: Always, unless overridden by the backend.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement list()"
        )

View File

@@ -0,0 +1,45 @@
from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator
class QdrantConfig(BaseModel):
    """Connection settings for Qdrant: a local path, host and port, or URL and API key."""

    host: Optional[str] = Field(None, description="Host address for Qdrant server")
    port: Optional[int] = Field(None, description="Port for Qdrant server")
    path: Optional[str] = Field(None, description="Path for local Qdrant database")
    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
    api_key: Optional[str] = Field(None, description="API key for Qdrant server")

    @model_validator(mode="before")
    @classmethod
    def check_host_port_or_path(cls, values):
        """
        Require at least one complete way to reach Qdrant.

        Raises:
            ValueError: If none of ``path``, ``host`` + ``port`` or
                ``url`` + ``api_key`` is provided.
        """
        # A "before" validator receives the raw input, which is not
        # necessarily a dict; only inspect it when it is one.
        if not isinstance(values, dict):
            return values
        host, port, path, url, api_key = (
            values.get("host"),
            values.get("port"),
            values.get("path"),
            values.get("url"),
            values.get("api_key"),
        )
        if not path and not (host and port) and not (url and api_key):
            raise ValueError(
                "Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
            )
        return values
class VectorStoreConfig(BaseModel):
    """Selects the vector store provider and carries its connection settings."""

    provider: str = Field(
        default="qdrant",
        description="Provider of the vector store (e.g., 'qdrant', 'chromadb', 'elasticsearch')",
    )
    config: QdrantConfig = Field(
        default=QdrantConfig(path="/tmp/qdrant"),
        description="Configuration for the specific vector store",
    )

    @field_validator("config")
    def validate_config(cls, v, values):
        """
        Rebuild ``config`` as a QdrantConfig when the provider is "qdrant".

        Raises:
            ValueError: For any other provider.
        """
        provider = values.data.get("provider")
        if provider != "qdrant":
            raise ValueError(f"Unsupported vector store provider: {provider}")
        settings = v.model_dump()
        return QdrantConfig(**settings)

View File

@@ -0,0 +1,242 @@
import logging
from typing import Optional
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
MatchValue,
PointIdsList,
PointStruct,
Range,
VectorParams,
)
from mem0.vector_stores.base import VectorStoreBase
class QdrantConfig(BaseModel):
    """Qdrant connection settings: server host and port, or a local database path."""

    host: Optional[str] = Field(default=None, description="Host address for Qdrant server")
    port: Optional[int] = Field(default=None, description="Port for Qdrant server")
    path: Optional[str] = Field(default=None, description="Path for local Qdrant database")
class Qdrant(VectorStoreBase):
    """Vector store backend implemented on top of ``QdrantClient``."""

    def __init__(
        self,
        client=None,
        host="localhost",
        port=6333,
        path=None,
        url=None,
        api_key=None,
    ):
        """
        Initialize the Qdrant vector store.

        Exactly one way of reaching Qdrant is forwarded to the client, chosen
        in this order: ``path``, then ``url``, then ``host`` + ``port``.

        Args:
            client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
            host (str, optional): Host address for Qdrant server. Defaults to "localhost".
            port (int, optional): Port for Qdrant server. Defaults to 6333.
            path (str, optional): Path for local Qdrant database. Defaults to None.
            url (str, optional): Full URL for Qdrant server. Defaults to None.
            api_key (str, optional): API key for Qdrant server. Defaults to None.
        """
        if client is not None:
            self.client = client
            return

        params = {}
        if api_key:
            params["api_key"] = api_key
        # The host/port defaults are always truthy, so forwarding them next to
        # an explicit path or url would hand the client two conflicting
        # locations. Fall back to host/port only when neither is given.
        if path:
            params["path"] = path
        elif url:
            params["url"] = url
        elif host and port:
            params["host"] = host
            params["port"] = port
        self.client = QdrantClient(**params)

    def create_col(self, name, vector_size, distance=Distance.COSINE):
        """
        Create a new collection.

        Does nothing when a collection with that name already exists.

        Args:
            name (str): Name of the collection.
            vector_size (int): Size of the vectors to be stored.
            distance (Distance, optional): Distance metric for vector similarity. Defaults to Distance.COSINE.
        """
        # Skip creating collection if already exists
        response = self.list_cols()
        if any(collection.name == name for collection in response.collections):
            logging.debug(f"Collection {name} already exists. Skipping creation.")
            return

        self.client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=vector_size, distance=distance),
        )

    def insert(self, name, vectors, payloads=None, ids=None):
        """
        Insert vectors into a collection.

        Args:
            name (str): Name of the collection.
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. When
                omitted, each vector's position in ``vectors`` is used as its ID.
                Defaults to None.
        """
        points = [
            PointStruct(
                id=idx if ids is None else ids[idx],
                vector=vector,
                payload=payloads[idx] if payloads else {},
            )
            for idx, vector in enumerate(vectors)
        ]
        self.client.upsert(collection_name=name, points=points)

    def _create_filter(self, filters):
        """
        Create a Filter object from the provided filters.

        A dict value containing both "gte" and "lte" becomes a range
        condition; every other value becomes an exact-match condition.

        Args:
            filters (dict): Filters to apply.

        Returns:
            Filter: The created Filter object, or None when ``filters`` is empty.
        """
        conditions = []
        for key, value in filters.items():
            if isinstance(value, dict) and "gte" in value and "lte" in value:
                conditions.append(
                    FieldCondition(
                        key=key, range=Range(gte=value["gte"], lte=value["lte"])
                    )
                )
            else:
                conditions.append(
                    FieldCondition(key=key, match=MatchValue(value=value))
                )
        return Filter(must=conditions) if conditions else None

    def search(self, name, query, limit=5, filters=None):
        """
        Search for similar vectors.

        Args:
            name (str): Name of the collection.
            query (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=name,
            query_vector=query,
            query_filter=query_filter,
            limit=limit,
        )
        return hits

    def delete(self, name, vector_id):
        """
        Delete a vector by ID.

        Args:
            name (str): Name of the collection.
            vector_id (int or str): ID of the vector to delete.
        """
        self.client.delete(
            collection_name=name,
            points_selector=PointIdsList(
                points=[vector_id],
            ),
        )

    def update(self, name, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        With a vector, the point is upserted and both vector and payload are
        replaced. Without a vector, only the payload is replaced; with neither,
        nothing happens.

        Args:
            name (str): Name of the collection.
            vector_id (int or str): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        if vector is None:
            # A point cannot be upserted without a vector, so replace just the
            # payload of the existing point.
            if payload is not None:
                self.client.overwrite_payload(
                    collection_name=name, payload=payload, points=[vector_id]
                )
            return

        point = PointStruct(id=vector_id, vector=vector, payload=payload)
        self.client.upsert(collection_name=name, points=[point])

    def get(self, name, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            name (str): Name of the collection.
            vector_id (int or str): ID of the vector to retrieve.

        Returns:
            The stored record (with payload), or None when the ID is unknown.
        """
        result = self.client.retrieve(
            collection_name=name, ids=[vector_id], with_payload=True
        )
        return result[0] if result else None

    def list_cols(self):
        """
        List all collections.

        Returns:
            The client's collections response; the collections themselves are
            under its ``collections`` attribute.
        """
        return self.client.get_collections()

    def delete_col(self, name):
        """
        Delete a collection.

        Args:
            name (str): Name of the collection to delete.
        """
        self.client.delete_collection(collection_name=name)

    def col_info(self, name):
        """
        Get information about a collection.

        Args:
            name (str): Name of the collection.

        Returns:
            Collection information as returned by the client.
        """
        return self.client.get_collection(collection_name=name)

    def list(self, name, filters=None, limit=100):
        """
        List vectors in a collection (one page of at most ``limit`` records).

        Args:
            name (str): Name of the collection.
            filters (dict, optional): Filters to apply. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            The client's scroll result, returned unmodified; callers read the
            records from element 0. Vectors themselves are not included.
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.scroll(
            collection_name=name,
            scroll_filter=query_filter,
            limit=limit,
            with_payload=True,
            with_vectors=False,
        )
        return result