+MongoDB Vector Support (#2367)

Co-authored-by: Divya Gupta <divya.gupta@mongodb.com>
This commit is contained in:
Fabian Valle
2025-06-14 08:27:06 -04:00
committed by GitHub
parent 7c0c4a03c4
commit a0cd4065d9
9 changed files with 571 additions and 1 deletions

View File

@@ -0,0 +1,49 @@
# MongoDB
[MongoDB](https://www.mongodb.com/) is a versatile document database that supports vector search capabilities, allowing for efficient high-dimensional similarity searches over large datasets with robust scalability and performance.
## Usage
```python
import os
from mem0 import Memory
os.environ["OPENAI_API_KEY"] = "sk-xx"
config = {
"vector_store": {
"provider": "mongodb",
"config": {
"db_name": "mem0-db",
"collection_name": "mem0-collection",
"user": "my-user",
"password": "my-password",
}
}
}
m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
{"role": "user", "content": "Im not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
## Config
Here are the parameters available for configuring MongoDB:
| Parameter | Description | Default Value |
| --- | --- | --- |
| db_name | Name of the MongoDB database | `"mem0_db"` |
| collection_name | Name of the MongoDB collection | `"mem0"` |
| embedding_model_dims | Dimensions of the embedding vectors | `1536` |
| user | MongoDB user for authentication | `None` |
| password | Password for the MongoDB user | `None` |
| host | MongoDB host | `"localhost"` |
| port | MongoDB port | `27017` |
> **Note**: `user` and `password` must either be provided together or omitted together.

View File

@@ -23,6 +23,7 @@ See the list of supported vector databases below.
<Card title="Upstash Vector" href="/components/vectordbs/dbs/upstash-vector"></Card>
<Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
<Card title="Pinecone" href="/components/vectordbs/dbs/pinecone"></Card>
<Card title="MongoDB" href="/components/vectordbs/dbs/mongodb"></Card>
<Card title="Azure" href="/components/vectordbs/dbs/azure"></Card>
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
<Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>

View File

@@ -137,6 +137,7 @@
"components/vectordbs/dbs/pgvector",
"components/vectordbs/dbs/milvus",
"components/vectordbs/dbs/pinecone",
"components/vectordbs/dbs/mongodb",
"components/vectordbs/dbs/azure",
"components/vectordbs/dbs/redis",
"components/vectordbs/dbs/elasticsearch",

View File

@@ -0,0 +1,42 @@
from typing import Any, Dict, Optional, Callable, List
from pydantic import BaseModel, Field, root_validator
class MongoVectorConfig(BaseModel):
    """Configuration for the MongoDB vector database.

    Every field has a default, so an empty configuration is valid. ``user``
    and ``password`` must be supplied together or not at all, and unknown
    keys are rejected.
    """

    db_name: str = Field("mem0_db", description="Name of the MongoDB database")
    collection_name: str = Field("mem0", description="Name of the MongoDB collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
    user: Optional[str] = Field(None, description="MongoDB user for authentication")
    password: Optional[str] = Field(None, description="Password for the MongoDB user")
    host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'")
    port: Optional[int] = Field(27017, description="MongoDB port. Default is 27017")

    @root_validator(pre=True)
    def check_auth_and_connection(cls, values):
        """Validate the credential pair and the connection target.

        Args:
            values: Raw input mapping, before field defaults are applied.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If only one of ``user``/``password`` is given, or if
                ``host``/``port`` is explicitly set to ``None``.
        """
        user = values.get("user")
        password = values.get("password")
        if (user is None) != (password is None):
            raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")

        # This validator runs with pre=True, so `values` holds only what the
        # caller passed; defaults are not filled in yet. A missing key means
        # "use the default" and must not be rejected; only an explicit None is.
        for key in ("host", "port"):
            if key in values and values[key] is None:
                raise ValueError(f"The '{key}' must be provided.")
        return values

    @root_validator(pre=True)
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input key that is not a declared field.

        Args:
            values: Raw input mapping, before field defaults are applied.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If ``values`` contains keys that are not model fields.
        """
        allowed_fields = list(cls.__fields__)
        # Sorted so the error message is deterministic across runs.
        extra_fields = sorted(set(values.keys()) - set(allowed_fields))
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please provide only the following fields: {', '.join(allowed_fields)}."
            )
        return values


# Compatibility alias: the vector-store config registry maps the "mongodb"
# provider to the name "MongoDBConfig".
MongoDBConfig = MongoVectorConfig

View File

@@ -79,6 +79,7 @@ class VectorStoreFactory:
"upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
"pinecone": "mem0.vector_stores.pinecone.PineconeDB",
"mongodb": "mem0.vector_stores.mongodb.MongoDB",
"redis": "mem0.vector_stores.redis.RedisDB",
"elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
"vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",

View File

@@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel):
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
"pinecone": "PineconeConfig",
"mongodb": "MongoDBConfig",
"milvus": "MilvusDBConfig",
"upstash_vector": "UpstashVectorConfig",
"azure_ai_search": "AzureAISearchConfig",

View File

@@ -0,0 +1,299 @@
import logging
from typing import List, Optional, Dict, Any, Callable
from pydantic import BaseModel
try:
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError:
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
from mem0.vector_stores.base import VectorStoreBase
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig() configures the *root* logger as an import-time
# side effect, which can override the host application's logging setup.
# Confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO)
class OutputData(BaseModel):
    """A single record returned by the vector store."""

    # Explicit ``None`` defaults keep these fields optional under pydantic v2,
    # where ``Optional[X]`` without a default is a *required* field.
    id: Optional[str] = None  # document ``_id`` rendered as a string
    score: Optional[float] = None  # similarity score; None for non-search reads
    payload: Optional[dict] = None  # payload stored alongside the embedding
class MongoVector(VectorStoreBase):
    """Vector store backed by a MongoDB collection and a search index.

    Documents are stored as ``{"_id", "embedding", "payload"}`` and queried
    with a ``$vectorSearch`` aggregation stage over the ``embedding`` field.
    """

    VECTOR_TYPE = "knnVector"
    SIMILARITY_METRIC = "cosine"

    def __init__(
        self,
        db_name: str,
        collection_name: str,
        embedding_model_dims: int,
        mongo_uri: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: str = "localhost",
        port: int = 27017,
    ):
        """
        Initialize the MongoDB vector store with vector search capabilities.

        Either pass a full ``mongo_uri`` or let one be built from
        ``user``/``password``/``host``/``port`` (the fields exposed by the
        MongoDB config model).

        Args:
            db_name (str): Database name
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            mongo_uri (str, optional): MongoDB connection URI. When given, it
                takes precedence over the discrete connection arguments.
            user (str, optional): User name for authentication.
            password (str, optional): Password for ``user``.
            host (str): MongoDB host. Defaults to "localhost".
            port (int): MongoDB port. Defaults to 27017.

        Raises:
            ValueError: If no ``mongo_uri`` is given and only one of
                ``user``/``password`` is provided.
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.db_name = db_name

        if mongo_uri is None:
            mongo_uri = self._build_uri(user, password, host, port)

        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.create_col()

    @staticmethod
    def _build_uri(user: Optional[str], password: Optional[str], host: str, port: int) -> str:
        """Build a ``mongodb://`` URI from discrete connection settings."""
        # Function-scope import keeps this block self-contained.
        from urllib.parse import quote_plus

        if (user is None) != (password is None):
            raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")

        credentials = ""
        if user is not None:
            # Credentials embedded in a URI must be percent-escaped.
            credentials = f"{quote_plus(user)}:{quote_plus(password)}@"
        return f"mongodb://{credentials}{host}:{port}"

    def create_col(self):
        """Create the collection and its vector search index if missing.

        Returns:
            The collection handle, or ``None`` if a PyMongo error occurred
            (the error is logged, not raised).
        """
        # Set before any I/O so the attribute exists even if the calls below fail.
        self.index_name = f"{self.collection_name}_vector_index"
        try:
            database = self.client[self.db_name]
            collection = database[self.collection_name]

            if self.collection_name not in database.list_collection_names():
                logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
                # Insert and remove a placeholder document to create the collection
                collection.insert_one({"_id": 0, "placeholder": True})
                collection.delete_one({"_id": 0})
                logger.info(f"Collection '{self.collection_name}' created successfully.")

            found_indexes = list(collection.list_search_indexes(name=self.index_name))
            if found_indexes:
                logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
            else:
                search_index_model = SearchIndexModel(
                    name=self.index_name,
                    definition={
                        "mappings": {
                            "dynamic": False,
                            "fields": {
                                "embedding": {
                                    "type": self.VECTOR_TYPE,
                                    "dimensions": self.embedding_model_dims,
                                    "similarity": self.SIMILARITY_METRIC,
                                }
                            },
                        }
                    },
                )
                collection.create_search_index(search_index_model)
                logger.info(
                    f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
                )
            return collection
        except PyMongoError as e:
            logger.error(f"Error creating collection and search index: {e}")
            return None

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
                When omitted, ``_id`` is left out so the server assigns one.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")
        payloads = payloads or [{}] * len(vectors)
        ids = ids or [None] * len(vectors)

        data = []
        for vector, payload, _id in zip(vectors, payloads, ids):
            document = {"embedding": vector, "payload": payload}
            # Writing `_id: None` would collide on the second document.
            if _id is not None:
                document["_id"] = _id
            data.append(document)

        if not data:
            # insert_many() rejects an empty batch; nothing to do.
            return
        try:
            self.collection.insert_many(data)
            logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error inserting data: {e}")

    def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors using the vector search index.

        Args:
            query (str): Query string (used only in error logging).
            query_vector (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search.

        Returns:
            List[OutputData]: Search results; empty if the index is missing
            or the query fails.
        """
        # NOTE(review): `filters` is accepted but not applied to the pipeline.
        # Confirm the intended filter semantics before wiring it in.
        found_indexes = list(self.collection.list_search_indexes(name=self.index_name))
        if not found_indexes:
            logger.error(f"Index '{self.index_name}' does not exist.")
            return []

        try:
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.index_name,
                        "limit": limit,
                        "numCandidates": limit,
                        "queryVector": query_vector,
                        "path": "embedding",
                    }
                },
                {"$set": {"score": {"$meta": "vectorSearchScore"}}},
                {"$project": {"embedding": 0}},
            ]
            results = list(self.collection.aggregate(pipeline))
            logger.info(f"Vector search completed. Found {len(results)} documents.")
        except Exception as e:
            logger.error(f"Error during vector search for query {query}: {e}")
            return []

        return [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results]

    def delete(self, vector_id: str) -> None:
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        try:
            result = self.collection.delete_one({"_id": vector_id})
            if result.deleted_count > 0:
                logger.info(f"Deleted document with ID '{vector_id}'.")
            else:
                logger.warning(f"No document found with ID '{vector_id}' to delete.")
        except PyMongoError as e:
            logger.error(f"Error deleting document: {e}")

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        update_fields = {}
        if vector is not None:
            update_fields["embedding"] = vector
        if payload is not None:
            update_fields["payload"] = payload
        if not update_fields:
            # Nothing to change; skip the round trip.
            return

        try:
            result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields})
            if result.matched_count > 0:
                logger.info(f"Updated document with ID '{vector_id}'.")
            else:
                logger.warning(f"No document found with ID '{vector_id}' to update.")
        except PyMongoError as e:
            logger.error(f"Error updating document: {e}")

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved vector or None if not found.
        """
        try:
            doc = self.collection.find_one({"_id": vector_id})
            if doc:
                logger.info(f"Retrieved document with ID '{vector_id}'.")
                return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
            logger.warning(f"Document with ID '{vector_id}' not found.")
            return None
        except PyMongoError as e:
            logger.error(f"Error retrieving document: {e}")
            return None

    def list_cols(self) -> List[str]:
        """
        List all collections in the database.

        Returns:
            List[str]: List of collection names.
        """
        try:
            collections = self.db.list_collection_names()
            logger.info(f"Listing collections in database '{self.db_name}': {collections}")
            return collections
        except PyMongoError as e:
            logger.error(f"Error listing collections: {e}")
            return []

    def delete_col(self) -> None:
        """Delete the collection."""
        try:
            self.collection.drop()
            logger.info(f"Deleted collection '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error deleting collection: {e}")

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: Collection name, document count and size; empty
            on error.
        """
        try:
            stats = self.db.command("collstats", self.collection_name)
            info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")}
            logger.info(f"Collection info: {info}")
            return info
        except PyMongoError as e:
            logger.error(f"Error getting collection info: {e}")
            return {}

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Query document passed to ``find``.
            limit (int, optional): Number of vectors to return.

        Returns:
            List[OutputData]: List of vectors.
        """
        try:
            query = filters or {}
            cursor = self.collection.find(query).limit(limit)
            results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor]
            logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
            return results
        except PyMongoError as e:
            logger.error(f"Error listing documents: {e}")
            return []

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        # create_col() takes no arguments; it reads the name from self.
        self.collection = self.create_col()

    def __del__(self) -> None:
        """Close the database connection when the object is deleted."""
        # The attribute is absent if __init__ failed before creating the client.
        if hasattr(self, "client"):
            self.client.close()
            logger.info("MongoClient connection closed.")


# Compatibility alias: the vector-store factory registers this provider as
# "mem0.vector_stores.mongodb.MongoDB".
MongoDB = MongoVector

View File

@@ -0,0 +1,176 @@
import time
import pytest
from unittest.mock import MagicMock, patch
from mem0.vector_stores.mongodb import MongoVector
from pymongo.operations import SearchIndexModel
@pytest.fixture
@patch("mem0.vector_stores.mongodb.MongoClient")
def mongo_vector_fixture(mock_mongo_client):
    """Build a MongoVector wired to a mocked MongoClient.

    Returns:
        tuple: ``(mongo_vector, mock_collection, mock_db)``, where the mocks
        are the objects the store obtains from the patched client.
    """
    mock_client = mock_mongo_client.return_value
    mock_db = mock_client["test_db"]
    mock_collection = mock_db["test_collection"]

    # Start from an empty database: no collections, no indexes, no documents.
    mock_collection.list_search_indexes.return_value = []
    mock_collection.aggregate.return_value = []
    mock_collection.find_one.return_value = None
    mock_collection.find.return_value = []
    mock_db.list_collection_names.return_value = []

    # The constructor takes a connection URI, not separate user/password
    # keyword arguments.
    mongo_vector = MongoVector(
        db_name="test_db",
        collection_name="test_collection",
        embedding_model_dims=1536,
        mongo_uri="mongodb://username:password@localhost:27017",
    )
    return mongo_vector, mock_collection, mock_db
def test_initalize_create_col(mongo_vector_fixture):
    """Construction stores the settings and creates the collection and index."""
    mongo_vector, mock_collection, mock_db = mongo_vector_fixture
    assert mongo_vector.collection_name == "test_collection"
    assert mongo_vector.embedding_model_dims == 1536
    assert mongo_vector.db_name == "test_db"

    # Verify create_col being called
    mock_db.list_collection_names.assert_called_once()
    mock_collection.insert_one.assert_called_once_with({"_id": 0, "placeholder": True})
    mock_collection.delete_one.assert_called_once_with({"_id": 0})
    assert mongo_vector.index_name == "test_collection_vector_index"
    mock_collection.list_search_indexes.assert_called_once_with(name="test_collection_vector_index")
    mock_collection.create_search_index.assert_called_once()

    args, _ = mock_collection.create_search_index.call_args
    search_index_model = args[0].document
    # The store emits the vector size under the "dimensions" key.
    assert search_index_model == {
        "name": "test_collection_vector_index",
        "definition": {
            "mappings": {
                "dynamic": False,
                "fields": {
                    "embedding": {
                        "type": "knnVector",
                        "dimensions": 1536,
                        "similarity": "cosine",
                    }
                },
            }
        },
    }
    assert mongo_vector.collection == mock_collection
def test_insert(mongo_vector_fixture):
    """insert() sends one document per vector, pairing ids and payloads."""
    store, collection, _ = mongo_vector_fixture
    ids = ["id1", "id2"]
    vectors = [[0.1] * 1536, [0.2] * 1536]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]

    store.insert(vectors, payloads, ids)

    expected_records = [
        {"_id": doc_id, "embedding": vector, "payload": payload}
        for doc_id, vector, payload in zip(ids, vectors, payloads)
    ]
    collection.insert_many.assert_called_once_with(expected_records)
def test_search(mongo_vector_fixture):
    """search() runs the vector-search pipeline and maps hits to results."""
    store, collection, _ = mongo_vector_fixture
    index_name = "test_collection_vector_index"
    query_vector = [0.1] * 1536
    collection.aggregate.return_value = [
        {"_id": "id1", "score": 0.9, "payload": {"key": "value1"}},
        {"_id": "id2", "score": 0.8, "payload": {"key": "value2"}},
    ]
    collection.list_search_indexes.return_value = [index_name]

    hits = store.search("query_str", query_vector, limit=2)

    collection.list_search_indexes.assert_called_with(name=index_name)
    expected_pipeline = [
        {
            "$vectorSearch": {
                "index": index_name,
                "limit": 2,
                "numCandidates": 2,
                "queryVector": query_vector,
                "path": "embedding",
            },
        },
        {"$set": {"score": {"$meta": "vectorSearchScore"}}},
        {"$project": {"embedding": 0}},
    ]
    collection.aggregate.assert_called_once_with(expected_pipeline)

    assert len(hits) == 2
    first, second = hits
    assert first.id == "id1"
    assert first.score == 0.9
    assert second.id == "id2"
    assert second.score == 0.8
def test_delete(mongo_vector_fixture):
    """delete() removes the document whose _id matches the given id."""
    store, collection, _ = mongo_vector_fixture
    # Report one deleted document so the success branch is taken.
    outcome = MagicMock()
    outcome.deleted_count = 1
    collection.delete_one.return_value = outcome

    store.delete("id1")

    collection.delete_one.assert_called_with({"_id": "id1"})
def test_update(mongo_vector_fixture):
    """update() issues a single $set with both the new vector and payload."""
    store, collection, _ = mongo_vector_fixture
    # Report one matched document so the success branch is taken.
    outcome = MagicMock()
    outcome.matched_count = 1
    collection.update_one.return_value = outcome

    doc_id = "id1"
    new_vector = [0.2] * 1536
    new_payload = {"key": "updated"}
    store.update(doc_id, vector=new_vector, payload=new_payload)

    expected_filter = {"_id": doc_id}
    expected_change = {"$set": {"embedding": new_vector, "payload": new_payload}}
    collection.update_one.assert_called_once_with(expected_filter, expected_change)
def test_get(mongo_vector_fixture):
    """get() returns the stored document's id and payload."""
    store, collection, _ = mongo_vector_fixture
    stored = {"_id": "id1", "payload": {"key": "value1"}}
    collection.find_one.return_value = stored

    fetched = store.get("id1")

    assert fetched is not None
    assert fetched.id == "id1"
    assert fetched.payload == {"key": "value1"}
def test_list_cols(mongo_vector_fixture):
    """list_cols() returns the collection names reported by the database."""
    store, _, database = mongo_vector_fixture
    database.list_collection_names.return_value = ["col1", "col2"]

    names = store.list_cols()

    assert names == ["col1", "col2"]
def test_delete_col(mongo_vector_fixture):
    """delete_col() drops the underlying collection exactly once."""
    store, collection, _ = mongo_vector_fixture

    store.delete_col()

    collection.drop.assert_called_once()
def test_col_info(mongo_vector_fixture):
    """col_info() reports name, count and size taken from collstats."""
    store, _, database = mongo_vector_fixture
    database.command.return_value = {"count": 10, "size": 1024}

    summary = store.col_info()

    database.command.assert_called_once_with("collstats", "test_collection")
    assert summary["name"] == "test_collection"
    assert summary["count"] == 10
    assert summary["size"] == 1024
def test_list(mongo_vector_fixture):
    """list() forwards filters and limit, and maps documents to results."""
    store, collection, _ = mongo_vector_fixture
    # find() returns a cursor whose limit() yields the documents.
    cursor = MagicMock()
    cursor.limit.return_value = [
        {"_id": "id1", "payload": {"key": "value1"}},
        {"_id": "id2", "payload": {"key": "value2"}},
    ]
    collection.find.return_value = cursor
    query_filters = {"_id": {"$in": ["id1", "id2"]}}

    listed = store.list(filters=query_filters, limit=2)

    collection.find.assert_called_once_with(query_filters)
    cursor.limit.assert_called_once_with(2)
    assert len(listed) == 2
    first, second = listed
    assert first.id == "id1"
    assert first.payload == {"key": "value1"}
    assert second.id == "id2"
    assert second.payload == {"key": "value2"}