+MongoDB Vector Support (#2367)
Co-authored-by: Divya Gupta <divya.gupta@mongodb.com>
mem0/configs/vector_stores/mongodb.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, root_validator


class MongoVectorConfig(BaseModel):
    """Configuration for MongoDB vector database."""

    db_name: str = Field("mem0_db", description="Name of the MongoDB database")
    collection_name: str = Field("mem0", description="Name of the MongoDB collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
    user: Optional[str] = Field(None, description="MongoDB user for authentication")
    password: Optional[str] = Field(None, description="Password for the MongoDB user")
    host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'")
    port: Optional[int] = Field(27017, description="MongoDB port. Default is 27017")

    @root_validator(pre=True)
    def check_auth_and_connection(cls, values):
        user = values.get("user")
        password = values.get("password")
        if (user is None) != (password is None):
            raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")

        # This validator runs with pre=True, i.e. on the raw input before field
        # defaults are applied, so only reject host/port when they are explicitly
        # passed as None; omitting them falls back to the Field defaults above.
        if "host" in values and values["host"] is None:
            raise ValueError("The 'host' must be provided.")
        if "port" in values and values["port"] is None:
            raise ValueError("The 'port' must be provided.")
        return values

    @root_validator(pre=True)
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.__fields__)
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please provide only the following fields: {', '.join(allowed_fields)}."
            )
        return values
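Not part of the diff: a short, hedged sketch of how this config behaves under pydantic validation, assuming the module is importable at the path added above and that validator errors surface as pydantic ValidationError.

from pydantic import ValidationError

from mem0.configs.vector_stores.mongodb import MongoVectorConfig

# Valid: user and password supplied together; host/port fall back to their defaults.
config = MongoVectorConfig(user="admin", password="secret")
print(config.host, config.port)  # localhost 27017

# Invalid: user without password trips check_auth_and_connection.
try:
    MongoVectorConfig(user="admin")
except ValidationError as e:
    print(e)

# Invalid: unknown keys are rejected by validate_extra_fields.
try:
    MongoVectorConfig(mongo_uri="mongodb://localhost:27017")
except ValidationError as e:
    print(e)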
@@ -79,6 +79,7 @@ class VectorStoreFactory:
        "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
        "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
        "pinecone": "mem0.vector_stores.pinecone.PineconeDB",
        "mongodb": "mem0.vector_stores.mongodb.MongoVector",
        "redis": "mem0.vector_stores.redis.RedisDB",
        "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
        "vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
@@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel):
        "chroma": "ChromaDbConfig",
        "pgvector": "PGVectorConfig",
        "pinecone": "PineconeConfig",
        "mongodb": "MongoVectorConfig",
        "milvus": "MilvusDBConfig",
        "upstash_vector": "UpstashVectorConfig",
        "azure_ai_search": "AzureAISearchConfig",
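For context (not part of the diff), a minimal sketch of how a provider key such as "mongodb" can be resolved from a dotted-path mapping like the one above; this is an illustrative helper, not the actual mem0 factory code.

import importlib


def load_class(dotted_path: str):
    """Resolve 'package.module.ClassName' into the class object (illustrative helper)."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# e.g. the entry registered above
MongoVectorClass = load_class("mem0.vector_stores.mongodb.MongoVector")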
mem0/vector_stores/mongodb.py (new file, 299 lines)
@@ -0,0 +1,299 @@
import logging
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

try:
    from pymongo import MongoClient
    from pymongo.errors import PyMongoError
    from pymongo.operations import SearchIndexModel
except ImportError:
    raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class MongoVector(VectorStoreBase):
    VECTOR_TYPE = "knnVector"
    SIMILARITY_METRIC = "cosine"

    def __init__(self, db_name: str, collection_name: str, embedding_model_dims: int, mongo_uri: str):
        """
        Initialize the MongoDB vector store with vector search capabilities.

        Args:
            db_name (str): Database name
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            mongo_uri (str): MongoDB connection URI
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.db_name = db_name

        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.create_col()
    def create_col(self):
        """Create the collection and its vector search index if they do not already exist."""
        try:
            database = self.client[self.db_name]
            collection_names = database.list_collection_names()
            if self.collection_name not in collection_names:
                logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
                collection = database[self.collection_name]
                # Insert and remove a placeholder document to create the collection
                collection.insert_one({"_id": 0, "placeholder": True})
                collection.delete_one({"_id": 0})
                logger.info(f"Collection '{self.collection_name}' created successfully.")
            else:
                collection = database[self.collection_name]

            self.index_name = f"{self.collection_name}_vector_index"
            found_indexes = list(collection.list_search_indexes(name=self.index_name))
            if found_indexes:
                logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
            else:
                search_index_model = SearchIndexModel(
                    name=self.index_name,
                    definition={
                        "mappings": {
                            "dynamic": False,
                            "fields": {
                                "embedding": {
                                    "type": self.VECTOR_TYPE,
                                    "dimensions": self.embedding_model_dims,
                                    "similarity": self.SIMILARITY_METRIC,
                                }
                            },
                        }
                    },
                )
                collection.create_search_index(search_index_model)
                logger.info(
                    f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
                )
            return collection
        except PyMongoError as e:
            logger.error(f"Error creating collection and search index: {e}")
            return None
    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")

        data = []
        for vector, payload, _id in zip(vectors, payloads or [{}] * len(vectors), ids or [None] * len(vectors)):
            document = {"_id": _id, "embedding": vector, "payload": payload}
            data.append(document)
        try:
            self.collection.insert_many(data)
            logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error inserting data: {e}")
    def search(
        self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors using the vector search index.

        Args:
            query (str): Query string
            query_vector (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search.

        Returns:
            List[OutputData]: Search results.
        """
        found_indexes = list(self.collection.list_search_indexes(name=self.index_name))
        if not found_indexes:
            logger.error(f"Index '{self.index_name}' does not exist.")
            return []

        results = []
        try:
            collection = self.client[self.db_name][self.collection_name]
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.index_name,
                        "limit": limit,
                        "numCandidates": limit,
                        "queryVector": query_vector,
                        "path": "embedding",
                    }
                },
                {"$set": {"score": {"$meta": "vectorSearchScore"}}},
                {"$project": {"embedding": 0}},
            ]
            results = list(collection.aggregate(pipeline))
            logger.info(f"Vector search completed. Found {len(results)} documents.")
        except Exception as e:
            logger.error(f"Error during vector search for query {query}: {e}")
            return []

        output = [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results]
        return output
    def delete(self, vector_id: str) -> None:
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        try:
            result = self.collection.delete_one({"_id": vector_id})
            if result.deleted_count > 0:
                logger.info(f"Deleted document with ID '{vector_id}'.")
            else:
                logger.warning(f"No document found with ID '{vector_id}' to delete.")
        except PyMongoError as e:
            logger.error(f"Error deleting document: {e}")

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        update_fields = {}
        if vector is not None:
            update_fields["embedding"] = vector
        if payload is not None:
            update_fields["payload"] = payload

        if update_fields:
            try:
                result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields})
                if result.matched_count > 0:
                    logger.info(f"Updated document with ID '{vector_id}'.")
                else:
                    logger.warning(f"No document found with ID '{vector_id}' to update.")
            except PyMongoError as e:
                logger.error(f"Error updating document: {e}")

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved vector or None if not found.
        """
        try:
            doc = self.collection.find_one({"_id": vector_id})
            if doc:
                logger.info(f"Retrieved document with ID '{vector_id}'.")
                return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
            else:
                logger.warning(f"Document with ID '{vector_id}' not found.")
                return None
        except PyMongoError as e:
            logger.error(f"Error retrieving document: {e}")
            return None

    def list_cols(self) -> List[str]:
        """
        List all collections in the database.

        Returns:
            List[str]: List of collection names.
        """
        try:
            collections = self.db.list_collection_names()
            logger.info(f"Listing collections in database '{self.db_name}': {collections}")
            return collections
        except PyMongoError as e:
            logger.error(f"Error listing collections: {e}")
            return []

    def delete_col(self) -> None:
        """Delete the collection."""
        try:
            self.collection.drop()
            logger.info(f"Deleted collection '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error deleting collection: {e}")

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: Collection information.
        """
        try:
            stats = self.db.command("collstats", self.collection_name)
            info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")}
            logger.info(f"Collection info: {info}")
            return info
        except PyMongoError as e:
            logger.error(f"Error getting collection info: {e}")
            return {}

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return.

        Returns:
            List[OutputData]: List of vectors.
        """
        try:
            query = filters or {}
            cursor = self.collection.find(query).limit(limit)
            results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor]
            logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
            return results
        except PyMongoError as e:
            logger.error(f"Error listing documents: {e}")
            return []
    def reset(self):
        """Reset the collection by deleting and recreating it."""
        logger.warning(f"Resetting collection {self.collection_name}...")
        self.delete_col()
        # create_col() takes no arguments; it reads the collection name from self.
        self.collection = self.create_col()

    def __del__(self) -> None:
        """Close the database connection when the object is deleted."""
        if hasattr(self, "client"):
            self.client.close()
            logger.info("MongoClient connection closed.")
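Continuing the sketch above, the remaining record-management helpers. Two observations about the code as written: search() accepts a filters argument but the aggregation pipeline does not currently apply it, and numCandidates is set equal to limit, whereas Atlas guidance generally recommends a noticeably larger candidate pool for better recall.

# Fetch, update, and delete individual records by ID (continuing from the ids above).
record = store.get(ids[0])
if record is not None:
    store.update(ids[0], payload={"text": "hello (edited)"})

store.delete(ids[1])

# Listing and collection-level helpers.
print(store.list(limit=10))
print(store.list_cols())
print(store.col_info())

# Drops the collection and recreates it together with its search index.
store.reset()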