From a0cd4065d9240329cb9b20acac851d3e3d5c69d5 Mon Sep 17 00:00:00 2001 From: Fabian Valle Date: Sat, 14 Jun 2025 08:27:06 -0400 Subject: [PATCH] +MongoDB Vector Support (#2367) Co-authored-by: Divya Gupta --- docs/components/vectordbs/dbs/mongodb.mdx | 49 ++++ docs/components/vectordbs/overview.mdx | 1 + docs/docs.json | 1 + mem0/configs/vector_stores/mongodb.py | 42 +++ mem0/utils/factory.py | 1 + mem0/vector_stores/configs.py | 1 + mem0/vector_stores/mongodb.py | 299 ++++++++++++++++++++++ poetry.lock | 2 +- tests/vector_stores/test_mongodb.py | 176 +++++++++++++ 9 files changed, 571 insertions(+), 1 deletion(-) create mode 100644 docs/components/vectordbs/dbs/mongodb.mdx create mode 100644 mem0/configs/vector_stores/mongodb.py create mode 100644 mem0/vector_stores/mongodb.py create mode 100644 tests/vector_stores/test_mongodb.py diff --git a/docs/components/vectordbs/dbs/mongodb.mdx b/docs/components/vectordbs/dbs/mongodb.mdx new file mode 100644 index 00000000..216c0436 --- /dev/null +++ b/docs/components/vectordbs/dbs/mongodb.mdx @@ -0,0 +1,49 @@ +# MongoDB + +[MongoDB](https://www.mongodb.com/) is a versatile document database that supports vector search capabilities, allowing for efficient high-dimensional similarity searches over large datasets with robust scalability and performance. + +## Usage + +```python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "mongodb", + "config": { + "db_name": "mem0-db", + "collection_name": "mem0-collection", + "user": "my-user", + "password": "my-password", + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! 
I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + +## Config + +Here are the parameters available for configuring MongoDB: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| db_name | Name of the MongoDB database | `"mem0_db"` | +| collection_name | Name of the MongoDB collection | `"mem0"` | +| embedding_model_dims | Dimensions of the embedding vectors | `1536` | +| user | MongoDB user for authentication | `None` | +| password | Password for the MongoDB user | `None` | +| host | MongoDB host | `"localhost"` | +| port | MongoDB port | `27017` | + +> **Note**: `user` and `password` must either be provided together or omitted together. diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx index f0d87be1..3309b7a4 100644 --- a/docs/components/vectordbs/overview.mdx +++ b/docs/components/vectordbs/overview.mdx @@ -23,6 +23,7 @@ See the list of supported vector databases below. 
+ diff --git a/docs/docs.json b/docs/docs.json index 65ae9590..60e8d2ab 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -137,6 +137,7 @@ "components/vectordbs/dbs/pgvector", "components/vectordbs/dbs/milvus", "components/vectordbs/dbs/pinecone", + "components/vectordbs/dbs/mongodb", "components/vectordbs/dbs/azure", "components/vectordbs/dbs/redis", "components/vectordbs/dbs/elasticsearch", diff --git a/mem0/configs/vector_stores/mongodb.py b/mem0/configs/vector_stores/mongodb.py new file mode 100644 index 00000000..3b6ce881 --- /dev/null +++ b/mem0/configs/vector_stores/mongodb.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, Optional, Callable, List + +from pydantic import BaseModel, Field, root_validator + + +class MongoVectorConfig(BaseModel): + """Configuration for MongoDB vector database.""" + + db_name: str = Field("mem0_db", description="Name of the MongoDB database") + collection_name: str = Field("mem0", description="Name of the MongoDB collection") + embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors") + user: Optional[str] = Field(None, description="MongoDB user for authentication") + password: Optional[str] = Field(None, description="Password for the MongoDB user") + host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'") + port: Optional[int] = Field(27017, description="MongoDB port. 
Default is 27017") + + @root_validator(pre=True) + def check_auth_and_connection(cls, values): + user = values.get("user") + password = values.get("password") + if (user is None) != (password is None): + raise ValueError("Both 'user' and 'password' must be provided together or omitted together.") + + host = values.get("host") + port = values.get("port") + if host is None: + raise ValueError("The 'host' must be provided.") + if port is None: + raise ValueError("The 'port' must be provided.") + return values + + @root_validator(pre=True) + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.__fields__) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. " + f"Please provide only the following fields: {', '.join(allowed_fields)}." + ) + return values diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index d137e273..4988b300 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -79,6 +79,7 @@ class VectorStoreFactory: "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector", "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch", "pinecone": "mem0.vector_stores.pinecone.PineconeDB", + "mongodb": "mem0.vector_stores.mongodb.MongoDB", "redis": "mem0.vector_stores.redis.RedisDB", "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB", "vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine", diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 43a2289f..e360d238 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel): "chroma": "ChromaDbConfig", "pgvector": "PGVectorConfig", "pinecone": "PineconeConfig", + "mongodb": "MongoDBConfig", "milvus": "MilvusDBConfig", "upstash_vector": "UpstashVectorConfig", 
"azure_ai_search": "AzureAISearchConfig", diff --git a/mem0/vector_stores/mongodb.py b/mem0/vector_stores/mongodb.py new file mode 100644 index 00000000..0a225fc5 --- /dev/null +++ b/mem0/vector_stores/mongodb.py @@ -0,0 +1,299 @@ +import logging +from typing import List, Optional, Dict, Any, Callable + +from pydantic import BaseModel + +try: + from pymongo import MongoClient + from pymongo.operations import SearchIndexModel + from pymongo.errors import PyMongoError +except ImportError: + raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.") + +from mem0.vector_stores.base import VectorStoreBase + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +class OutputData(BaseModel): + id: Optional[str] + score: Optional[float] + payload: Optional[dict] + + +class MongoVector(VectorStoreBase): + VECTOR_TYPE = "knnVector" + SIMILARITY_METRIC = "cosine" + + def __init__( + self, + db_name: str, + collection_name: str, + embedding_model_dims: int, + mongo_uri: str + ): + """ + Initialize the MongoDB vector store with vector search capabilities. + + Args: + db_name (str): Database name + collection_name (str): Collection name + embedding_model_dims (int): Dimension of the embedding vector + mongo_uri (str): MongoDB connection URI + """ + self.collection_name = collection_name + self.embedding_model_dims = embedding_model_dims + self.db_name = db_name + + self.client = MongoClient( + mongo_uri + ) + self.db = self.client[db_name] + self.collection = self.create_col() + + def create_col(self): + """Create new collection with vector search index.""" + try: + database = self.client[self.db_name] + collection_names = database.list_collection_names() + if self.collection_name not in collection_names: + logger.info(f"Collection '{self.collection_name}' does not exist. 
Creating it now.") + collection = database[self.collection_name] + # Insert and remove a placeholder document to create the collection + collection.insert_one({"_id": 0, "placeholder": True}) + collection.delete_one({"_id": 0}) + logger.info(f"Collection '{self.collection_name}' created successfully.") + else: + collection = database[self.collection_name] + + self.index_name = f"{self.collection_name}_vector_index" + found_indexes = list(collection.list_search_indexes(name=self.index_name)) + if found_indexes: + logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.") + else: + search_index_model = SearchIndexModel( + name=self.index_name, + definition={ + "mappings": { + "dynamic": False, + "fields": { + "embedding": { + "type": self.VECTOR_TYPE, + "dimensions": self.embedding_model_dims, + "similarity": self.SIMILARITY_METRIC, + } + }, + } + }, + ) + collection.create_search_index(search_index_model) + logger.info( + f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'." + ) + return collection + except PyMongoError as e: + logger.error(f"Error creating collection and search index: {e}") + return None + + def insert( + self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None + ) -> None: + """ + Insert vectors into the collection. + + Args: + vectors (List[List[float]]): List of vectors to insert. + payloads (List[Dict], optional): List of payloads corresponding to vectors. + ids (List[str], optional): List of IDs corresponding to vectors. 
+ """ + logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.") + + data = [] + for vector, payload, _id in zip(vectors, payloads or [{}] * len(vectors), ids or [None] * len(vectors)): + document = {"_id": _id, "embedding": vector, "payload": payload} + data.append(document) + try: + self.collection.insert_many(data) + logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.") + except PyMongoError as e: + logger.error(f"Error inserting data: {e}") + + def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]: + """ + Search for similar vectors using the vector search index. + + Args: + query (str): Query string + query_vector (List[float]): Query vector. + limit (int, optional): Number of results to return. Defaults to 5. + filters (Dict, optional): Filters to apply to the search. + + Returns: + List[OutputData]: Search results. + """ + + found_indexes = list(self.collection.list_search_indexes(name=self.index_name)) + if not found_indexes: + logger.error(f"Index '{self.index_name}' does not exist.") + return [] + + results = [] + try: + collection = self.client[self.db_name][self.collection_name] + pipeline = [ + { + "$vectorSearch": { + "index": self.index_name, + "limit": limit, + "numCandidates": limit, + "queryVector": query_vector, + "path": "embedding", + } + }, + {"$set": {"score": {"$meta": "vectorSearchScore"}}}, + {"$project": {"embedding": 0}}, + ] + results = list(collection.aggregate(pipeline)) + logger.info(f"Vector search completed. Found {len(results)} documents.") + except Exception as e: + logger.error(f"Error during vector search for query {query}: {e}") + return [] + + output = [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results] + return output + + def delete(self, vector_id: str) -> None: + """ + Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. 
+ """ + try: + result = self.collection.delete_one({"_id": vector_id}) + if result.deleted_count > 0: + logger.info(f"Deleted document with ID '{vector_id}'.") + else: + logger.warning(f"No document found with ID '{vector_id}' to delete.") + except PyMongoError as e: + logger.error(f"Error deleting document: {e}") + + def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None: + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + update_fields = {} + if vector is not None: + update_fields["embedding"] = vector + if payload is not None: + update_fields["payload"] = payload + + if update_fields: + try: + result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields}) + if result.matched_count > 0: + logger.info(f"Updated document with ID '{vector_id}'.") + else: + logger.warning(f"No document found with ID '{vector_id}' to update.") + except PyMongoError as e: + logger.error(f"Error updating document: {e}") + + def get(self, vector_id: str) -> Optional[OutputData]: + """ + Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + Optional[OutputData]: Retrieved vector or None if not found. + """ + try: + doc = self.collection.find_one({"_id": vector_id}) + if doc: + logger.info(f"Retrieved document with ID '{vector_id}'.") + return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) + else: + logger.warning(f"Document with ID '{vector_id}' not found.") + return None + except PyMongoError as e: + logger.error(f"Error retrieving document: {e}") + return None + + def list_cols(self) -> List[str]: + """ + List all collections in the database. + + Returns: + List[str]: List of collection names. 
+ """ + try: + collections = self.db.list_collection_names() + logger.info(f"Listing collections in database '{self.db_name}': {collections}") + return collections + except PyMongoError as e: + logger.error(f"Error listing collections: {e}") + return [] + + def delete_col(self) -> None: + """Delete the collection.""" + try: + self.collection.drop() + logger.info(f"Deleted collection '{self.collection_name}'.") + except PyMongoError as e: + logger.error(f"Error deleting collection: {e}") + + def col_info(self) -> Dict[str, Any]: + """ + Get information about the collection. + + Returns: + Dict[str, Any]: Collection information. + """ + try: + stats = self.db.command("collstats", self.collection_name) + info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")} + logger.info(f"Collection info: {info}") + return info + except PyMongoError as e: + logger.error(f"Error getting collection info: {e}") + return {} + + def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]: + """ + List vectors in the collection. + + Args: + filters (Dict, optional): Filters to apply to the list. + limit (int, optional): Number of vectors to return. + + Returns: + List[OutputData]: List of vectors. 
+ """ + try: + query = filters or {} + cursor = self.collection.find(query).limit(limit) + results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor] + logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.") + return results + except PyMongoError as e: + logger.error(f"Error listing documents: {e}") + return [] + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.collection = self.create_col(self.collection_name) + + def __del__(self) -> None: + """Close the database connection when the object is deleted.""" + if hasattr(self, "client"): + self.client.close() + logger.info("MongoClient connection closed.") diff --git a/poetry.lock b/poetry.lock index fd05427e..8784eec3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2200,4 +2200,4 @@ graph = ["langchain-neo4j", "neo4j", "rank-bm25"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "07f2aee9c596c2d2470df085b92551b7b7e3c19cabe61ae5bee7505395601417" +content-hash = "07f2aee9c596c2d2470df085b92551b7b7e3c19cabe61ae5bee7505395601417" \ No newline at end of file diff --git a/tests/vector_stores/test_mongodb.py b/tests/vector_stores/test_mongodb.py new file mode 100644 index 00000000..370d5322 --- /dev/null +++ b/tests/vector_stores/test_mongodb.py @@ -0,0 +1,176 @@ +import time +import pytest +from unittest.mock import MagicMock, patch +from mem0.vector_stores.mongodb import MongoVector +from pymongo.operations import SearchIndexModel + +@pytest.fixture +@patch("mem0.vector_stores.mongodb.MongoClient") +def mongo_vector_fixture(mock_mongo_client): + mock_client = mock_mongo_client.return_value + mock_db = mock_client["test_db"] + mock_collection = mock_db["test_collection"] + mock_collection.list_search_indexes.return_value = [] + mock_collection.aggregate.return_value = [] + 
mock_collection.find_one.return_value = None + mock_collection.find.return_value = [] + mock_db.list_collection_names.return_value = [] + + mongo_vector = MongoVector( + db_name="test_db", + collection_name="test_collection", + embedding_model_dims=1536, + user="username", + password="password", + ) + return mongo_vector, mock_collection, mock_db + +def test_initalize_create_col(mongo_vector_fixture): + mongo_vector, mock_collection, mock_db = mongo_vector_fixture + assert mongo_vector.collection_name == "test_collection" + assert mongo_vector.embedding_model_dims == 1536 + assert mongo_vector.db_name == "test_db" + + # Verify create_col being called + mock_db.list_collection_names.assert_called_once() + mock_collection.insert_one.assert_called_once_with({"_id": 0, "placeholder": True}) + mock_collection.delete_one.assert_called_once_with({"_id": 0}) + assert mongo_vector.index_name == "test_collection_vector_index" + mock_collection.list_search_indexes.assert_called_once_with(name="test_collection_vector_index") + mock_collection.create_search_index.assert_called_once() + args, _ = mock_collection.create_search_index.call_args + search_index_model = args[0].document + assert search_index_model == { + "name": "test_collection_vector_index", + "definition": { + "mappings": { + "dynamic": False, + "fields": { + "embedding": { + "type": "knnVector", + "d": 1536, + "similarity": "cosine", + } + } + } + } + } + assert mongo_vector.collection == mock_collection + +def test_insert(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + vectors = [[0.1] * 1536, [0.2] * 1536] + payloads = [{"name": "vector1"}, {"name": "vector2"}] + ids = ["id1", "id2"] + + mongo_vector.insert(vectors, payloads, ids) + expected_records=[ + ({"_id": ids[0], "embedding": vectors[0], "payload": payloads[0]}), + ({"_id": ids[1], "embedding": vectors[1], "payload": payloads[1]}) + ] + mock_collection.insert_many.assert_called_once_with(expected_records) + +def 
test_search(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + query_vector = [0.1] * 1536 + mock_collection.aggregate.return_value = [ + {"_id": "id1", "score": 0.9, "payload": {"key": "value1"}}, + {"_id": "id2", "score": 0.8, "payload": {"key": "value2"}}, + ] + mock_collection.list_search_indexes.return_value = ["test_collection_vector_index"] + + results = mongo_vector.search("query_str", query_vector, limit=2) + mock_collection.list_search_indexes.assert_called_with(name="test_collection_vector_index") + mock_collection.aggregate.assert_called_once_with([ + { + "$vectorSearch": { + "index": "test_collection_vector_index", + "limit": 2, + "numCandidates": 2, + "queryVector": query_vector, + "path": "embedding", + }, + }, + {"$set": {"score": {"$meta": "vectorSearchScore"}}}, + {"$project": {"embedding": 0}}, + ]) + assert len(results) == 2 + assert results[0].id == "id1" + assert results[0].score == 0.9 + assert results[1].id == "id2" + assert results[1].score == 0.8 + +def test_delete(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + mock_delete_result = MagicMock() + mock_delete_result.deleted_count = 1 + mock_collection.delete_one.return_value = mock_delete_result + + mongo_vector.delete("id1") + mock_collection.delete_one.assert_called_with({"_id": "id1"}) + +def test_update(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + mock_update_result = MagicMock() + mock_update_result.matched_count = 1 + mock_collection.update_one.return_value = mock_update_result + idValue = "id1" + vectorValue = [0.2] * 1536 + payloadValue = {"key": "updated"} + + mongo_vector.update(idValue, vector=vectorValue, payload=payloadValue) + mock_collection.update_one.assert_called_once_with( + {"_id": idValue}, + {"$set": {"embedding": vectorValue, "payload": payloadValue}}, + ) + +def test_get(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + 
mock_collection.find_one.return_value = {"_id": "id1", "payload": {"key": "value1"}} + + result = mongo_vector.get("id1") + assert result is not None + assert result.id == "id1" + assert result.payload == {"key": "value1"} + +def test_list_cols(mongo_vector_fixture): + mongo_vector, _, mock_db = mongo_vector_fixture + mock_db.list_collection_names.return_value = ["col1", "col2"] + + collections = mongo_vector.list_cols() + assert collections == ["col1", "col2"] + +def test_delete_col(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + + mongo_vector.delete_col() + mock_collection.drop.assert_called_once() + +def test_col_info(mongo_vector_fixture): + mongo_vector, _, mock_db = mongo_vector_fixture + mock_db.command.return_value = {"count": 10, "size": 1024} + + info = mongo_vector.col_info() + mock_db.command.assert_called_once_with("collstats", "test_collection") + assert info["name"] == "test_collection" + assert info["count"] == 10 + assert info["size"] == 1024 + +def test_list(mongo_vector_fixture): + mongo_vector, mock_collection, _ = mongo_vector_fixture + mock_cursor = MagicMock() + mock_cursor.limit.return_value = [ + {"_id": "id1", "payload": {"key": "value1"}}, + {"_id": "id2", "payload": {"key": "value2"}}, + ] + mock_collection.find.return_value = mock_cursor + + query_filters = {"_id": {"$in": ["id1", "id2"]}} + results = mongo_vector.list(filters=query_filters, limit=2) + mock_collection.find.assert_called_once_with(query_filters) + mock_cursor.limit.assert_called_once_with(2) + assert len(results) == 2 + assert results[0].id == "id1" + assert results[0].payload == {"key": "value1"} + assert results[1].id == "id2" + assert results[1].payload == {"key": "value2"}