+MongoDB Vector Support (#2367)
Co-authored-by: Divya Gupta <divya.gupta@mongodb.com>
This commit is contained in:
49
docs/components/vectordbs/dbs/mongodb.mdx
Normal file
49
docs/components/vectordbs/dbs/mongodb.mdx
Normal file
@@ -0,0 +1,49 @@
|
||||
# MongoDB
|
||||
|
||||
[MongoDB](https://www.mongodb.com/) is a versatile document database that supports vector search capabilities, allowing for efficient high-dimensional similarity searches over large datasets with robust scalability and performance.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
import os
|
||||
from mem0 import Memory
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "mongodb",
|
||||
"config": {
|
||||
"db_name": "mem0-db",
|
||||
"collection_name": "mem0-collection",
|
||||
"user": "my-user",
|
||||
"password": "my-password",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m = Memory.from_config(config)
|
||||
messages = [
|
||||
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||
{"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
|
||||
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
||||
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
||||
]
|
||||
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
Here are the parameters available for configuring MongoDB:
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| --- | --- | --- |
|
||||
| db_name | Name of the MongoDB database | `"mem0_db"` |
|
||||
| collection_name | Name of the MongoDB collection | `"mem0"` |
|
||||
| embedding_model_dims | Dimensions of the embedding vectors | `1536` |
|
||||
| user | MongoDB user for authentication | `None` |
|
||||
| password | Password for the MongoDB user | `None` |
|
||||
| host | MongoDB host | `"localhost"` |
|
||||
| port | MongoDB port | `27017` |
|
||||
|
||||
> **Note**: `user` and `password` must either be provided together or omitted together.
|
||||
@@ -23,6 +23,7 @@ See the list of supported vector databases below.
|
||||
<Card title="Upstash Vector" href="/components/vectordbs/dbs/upstash-vector"></Card>
|
||||
<Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
|
||||
<Card title="Pinecone" href="/components/vectordbs/dbs/pinecone"></Card>
|
||||
<Card title="MongoDB" href="/components/vectordbs/dbs/mongodb"></Card>
|
||||
<Card title="Azure" href="/components/vectordbs/dbs/azure"></Card>
|
||||
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
|
||||
<Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>
|
||||
|
||||
@@ -137,6 +137,7 @@
|
||||
"components/vectordbs/dbs/pgvector",
|
||||
"components/vectordbs/dbs/milvus",
|
||||
"components/vectordbs/dbs/pinecone",
|
||||
"components/vectordbs/dbs/mongodb",
|
||||
"components/vectordbs/dbs/azure",
|
||||
"components/vectordbs/dbs/redis",
|
||||
"components/vectordbs/dbs/elasticsearch",
|
||||
|
||||
42
mem0/configs/vector_stores/mongodb.py
Normal file
42
mem0/configs/vector_stores/mongodb.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict, Optional, Callable, List
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
|
||||
class MongoVectorConfig(BaseModel):
    """Configuration for MongoDB vector database.

    Both validators below run in "pre" mode, i.e. on the raw input mapping
    before pydantic applies field defaults or type coercion.
    """

    db_name: str = Field("mem0_db", description="Name of the MongoDB database")
    collection_name: str = Field("mem0", description="Name of the MongoDB collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
    user: Optional[str] = Field(None, description="MongoDB user for authentication")
    password: Optional[str] = Field(None, description="Password for the MongoDB user")
    host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'")
    port: Optional[int] = Field(27017, description="MongoDB port. Default is 27017")

    @root_validator(pre=True)
    def check_auth_and_connection(cls, values):
        """Validate the credential pair and the connection target.

        Raises:
            ValueError: If only one of ``user`` / ``password`` is given, or if
                ``host`` / ``port`` is explicitly set to ``None``.
        """
        user = values.get("user")
        password = values.get("password")
        if (user is None) != (password is None):
            raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")

        # Defaults have not been applied yet at this point, so a *missing* key
        # means "use the field default" and must be accepted. Only an explicit
        # None is an error; checking `values.get(...) is None` would reject
        # every config that relies on the documented defaults.
        if "host" in values and values["host"] is None:
            raise ValueError("The 'host' must be provided.")
        if "port" in values and values["port"] is None:
            raise ValueError("The 'port' must be provided.")
        return values

    @root_validator(pre=True)
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input key that is not a declared field.

        Raises:
            ValueError: Listing the unexpected keys and the accepted ones.
        """
        allowed_fields = set(cls.__fields__)
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            # Sorted so the message is stable across runs (set order is not).
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. "
                f"Please provide only the following fields: {', '.join(sorted(allowed_fields))}."
            )
        return values


# The provider table in VectorStoreConfig refers to this config class as
# "MongoDBConfig"; expose that name so the lookup resolves.
MongoDBConfig = MongoVectorConfig
|
||||
@@ -79,6 +79,7 @@ class VectorStoreFactory:
|
||||
"upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
|
||||
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
|
||||
"pinecone": "mem0.vector_stores.pinecone.PineconeDB",
|
||||
"mongodb": "mem0.vector_stores.mongodb.MongoDB",
|
||||
"redis": "mem0.vector_stores.redis.RedisDB",
|
||||
"elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
|
||||
"vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
|
||||
|
||||
@@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel):
|
||||
"chroma": "ChromaDbConfig",
|
||||
"pgvector": "PGVectorConfig",
|
||||
"pinecone": "PineconeConfig",
|
||||
"mongodb": "MongoDBConfig",
|
||||
"milvus": "MilvusDBConfig",
|
||||
"upstash_vector": "UpstashVectorConfig",
|
||||
"azure_ai_search": "AzureAISearchConfig",
|
||||
|
||||
299
mem0/vector_stores/mongodb.py
Normal file
299
mem0/vector_stores/mongodb.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any, Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
except ImportError:
|
||||
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
# Module-level logger used by every method of the store below.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig() at import time configures the process-wide root
# logger (when it has no handlers yet); libraries usually leave this to the
# application — confirm this is intended.
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
    """Record returned by the store's read operations (search / get / list)."""

    id: Optional[str]  # the document's "_id", converted with str()
    score: Optional[float]  # search score; None for get() and list() results
    payload: Optional[dict]  # the document's "payload" field, if any
|
||||
|
||||
|
||||
class MongoVector(VectorStoreBase):
    """Vector store backed by a MongoDB collection.

    Each stored document has the shape
    ``{"_id": ..., "embedding": [...], "payload": {...}}`` and the collection
    carries a search index named ``"<collection_name>_vector_index"`` on the
    ``embedding`` field.

    Most operations log driver errors (``PyMongoError``) and return an empty
    value instead of raising.
    """

    VECTOR_TYPE = "knnVector"
    SIMILARITY_METRIC = "cosine"

    def __init__(
        self,
        db_name: str,
        collection_name: str,
        embedding_model_dims: int,
        mongo_uri: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: str = "localhost",
        port: int = 27017,
    ):
        """
        Initialize the MongoDB vector store with vector search capabilities.

        Either pass a ready-made ``mongo_uri`` or let the URI be assembled from
        ``user`` / ``password`` / ``host`` / ``port`` (the fields exposed by
        ``MongoVectorConfig``). ``mongo_uri`` wins when both are given.

        Args:
            db_name (str): Database name
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            mongo_uri (str, optional): MongoDB connection URI
            user (str, optional): User name, used only when ``mongo_uri`` is None
            password (str, optional): Password, used only when ``mongo_uri`` is None
            host (str): Host, used only when ``mongo_uri`` is None
            port (int): Port, used only when ``mongo_uri`` is None

        Raises:
            ValueError: If only one of ``user`` / ``password`` is supplied.
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.db_name = db_name

        if mongo_uri is None:
            mongo_uri = self._build_uri(user, password, host, port)

        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.create_col()

    @staticmethod
    def _build_uri(user: Optional[str], password: Optional[str], host: str, port: int) -> str:
        """Assemble a ``mongodb://`` URI from discrete connection settings."""
        from urllib.parse import quote_plus

        if (user is None) != (password is None):
            raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")

        credentials = ""
        if user is not None:
            # Reserved characters in credentials must be percent-escaped.
            credentials = f"{quote_plus(user)}:{quote_plus(password)}@"
        return f"mongodb://{credentials}{host}:{port}"

    def create_col(self):
        """Create new collection with vector search index.

        Returns:
            The collection handle, or None if a driver error occurred (the
            error is logged, not raised).
        """
        # Set before any I/O so later calls (e.g. search) can rely on the
        # attribute even when creation fails part-way.
        self.index_name = f"{self.collection_name}_vector_index"
        try:
            database = self.client[self.db_name]
            collection_names = database.list_collection_names()
            if self.collection_name not in collection_names:
                logger.info(f"Collection '{self.collection_name}' does not exist. Creating it now.")
                collection = database[self.collection_name]
                # Insert and remove a placeholder document to create the collection
                collection.insert_one({"_id": 0, "placeholder": True})
                collection.delete_one({"_id": 0})
                logger.info(f"Collection '{self.collection_name}' created successfully.")
            else:
                collection = database[self.collection_name]

            found_indexes = list(collection.list_search_indexes(name=self.index_name))
            if found_indexes:
                logger.info(f"Search index '{self.index_name}' already exists in collection '{self.collection_name}'.")
            else:
                search_index_model = SearchIndexModel(
                    name=self.index_name,
                    definition={
                        "mappings": {
                            "dynamic": False,
                            "fields": {
                                "embedding": {
                                    "type": self.VECTOR_TYPE,
                                    "dimensions": self.embedding_model_dims,
                                    "similarity": self.SIMILARITY_METRIC,
                                }
                            },
                        }
                    },
                )
                collection.create_search_index(search_index_model)
                logger.info(
                    f"Search index '{self.index_name}' created successfully for collection '{self.collection_name}'."
                )
            return collection
        except PyMongoError as e:
            logger.error(f"Error creating collection and search index: {e}")
            return None

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> None:
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Entries that are None (or a missing ``ids`` list) leave ``_id``
                unset so that an id is generated on insert.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection '{self.collection_name}'.")

        # A fresh dict per document: `[{}] * n` would alias one shared dict.
        payloads = payloads or [{} for _ in vectors]
        ids = ids or [None] * len(vectors)

        data = []
        for vector, payload, _id in zip(vectors, payloads, ids):
            document = {"embedding": vector, "payload": payload}
            if _id is not None:
                # Writing `_id: None` for every document would collide on the
                # unique _id index from the second document onwards.
                document = {"_id": _id, **document}
            data.append(document)

        if not data:
            # insert_many() rejects an empty batch; nothing to do.
            logger.info(f"No documents to insert into '{self.collection_name}'.")
            return

        try:
            self.collection.insert_many(data)
            logger.info(f"Inserted {len(data)} documents into '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error inserting data: {e}")

    def search(self, query: str, query_vector: List[float], limit=5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors using the vector search index.

        Args:
            query (str): Query string
            query_vector (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search.

        Returns:
            List[OutputData]: Search results; empty if the index is missing or
            the aggregation fails.
        """
        # NOTE(review): `filters` is accepted but not applied to the pipeline
        # below — results are unfiltered. Confirm the intended filter schema
        # before wiring it in.
        found_indexes = list(self.collection.list_search_indexes(name=self.index_name))
        if not found_indexes:
            logger.error(f"Index '{self.index_name}' does not exist.")
            return []

        results = []
        try:
            collection = self.client[self.db_name][self.collection_name]
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.index_name,
                        "limit": limit,
                        "numCandidates": limit,
                        "queryVector": query_vector,
                        "path": "embedding",
                    }
                },
                {"$set": {"score": {"$meta": "vectorSearchScore"}}},
                # Drop the raw vector from the returned documents.
                {"$project": {"embedding": 0}},
            ]
            results = list(collection.aggregate(pipeline))
            logger.info(f"Vector search completed. Found {len(results)} documents.")
        except Exception as e:
            logger.error(f"Error during vector search for query {query}: {e}")
            return []

        output = [OutputData(id=str(doc["_id"]), score=doc.get("score"), payload=doc.get("payload")) for doc in results]
        return output

    def delete(self, vector_id: str) -> None:
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        try:
            result = self.collection.delete_one({"_id": vector_id})
            if result.deleted_count > 0:
                logger.info(f"Deleted document with ID '{vector_id}'.")
            else:
                logger.warning(f"No document found with ID '{vector_id}' to delete.")
        except PyMongoError as e:
            logger.error(f"Error deleting document: {e}")

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """
        Update a vector and its payload.

        Only the arguments that are not None are written; if both are None no
        request is sent.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        update_fields = {}
        if vector is not None:
            update_fields["embedding"] = vector
        if payload is not None:
            update_fields["payload"] = payload

        if update_fields:
            try:
                result = self.collection.update_one({"_id": vector_id}, {"$set": update_fields})
                if result.matched_count > 0:
                    logger.info(f"Updated document with ID '{vector_id}'.")
                else:
                    logger.warning(f"No document found with ID '{vector_id}' to update.")
            except PyMongoError as e:
                logger.error(f"Error updating document: {e}")

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: Retrieved vector or None if not found (or on
            a driver error).
        """
        try:
            doc = self.collection.find_one({"_id": vector_id})
            if doc:
                logger.info(f"Retrieved document with ID '{vector_id}'.")
                return OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload"))
            else:
                logger.warning(f"Document with ID '{vector_id}' not found.")
                return None
        except PyMongoError as e:
            logger.error(f"Error retrieving document: {e}")
            return None

    def list_cols(self) -> List[str]:
        """
        List all collections in the database.

        Returns:
            List[str]: List of collection names; empty on a driver error.
        """
        try:
            collections = self.db.list_collection_names()
            logger.info(f"Listing collections in database '{self.db_name}': {collections}")
            return collections
        except PyMongoError as e:
            logger.error(f"Error listing collections: {e}")
            return []

    def delete_col(self) -> None:
        """Delete the collection."""
        try:
            self.collection.drop()
            logger.info(f"Deleted collection '{self.collection_name}'.")
        except PyMongoError as e:
            logger.error(f"Error deleting collection: {e}")

    def col_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: ``name``, ``count`` and ``size`` taken from the
            ``collstats`` command; empty dict on a driver error.
        """
        try:
            stats = self.db.command("collstats", self.collection_name)
            info = {"name": self.collection_name, "count": stats.get("count"), "size": stats.get("size")}
            logger.info(f"Collection info: {info}")
            return info
        except PyMongoError as e:
            logger.error(f"Error getting collection info: {e}")
            return {}

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Passed as-is as the ``find()`` query.
            limit (int, optional): Number of vectors to return.

        Returns:
            List[OutputData]: List of vectors (score is always None); empty on
            a driver error.
        """
        try:
            query = filters or {}
            cursor = self.collection.find(query).limit(limit)
            results = [OutputData(id=str(doc["_id"]), score=None, payload=doc.get("payload")) for doc in cursor]
            logger.info(f"Retrieved {len(results)} documents from collection '{self.collection_name}'.")
            return results
        except PyMongoError as e:
            logger.error(f"Error listing documents: {e}")
            return []

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        # create_col() takes no arguments; it reads the name from self.
        self.collection = self.create_col()

    def __del__(self) -> None:
        """Close the database connection when the object is deleted."""
        # hasattr guard: __init__ may have failed before the client was set.
        if hasattr(self, "client"):
            self.client.close()
            logger.info("MongoClient connection closed.")


# The vector-store factory registers this provider under the dotted path
# "mem0.vector_stores.mongodb.MongoDB"; expose that name so the lookup resolves.
MongoDB = MongoVector
|
||||
176
tests/vector_stores/test_mongodb.py
Normal file
176
tests/vector_stores/test_mongodb.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from mem0.vector_stores.mongodb import MongoVector
|
||||
from pymongo.operations import SearchIndexModel
|
||||
|
||||
@pytest.fixture
@patch("mem0.vector_stores.mongodb.MongoClient")
def mongo_vector_fixture(mock_mongo_client):
    """Build a MongoVector wired to a mocked MongoClient.

    Returns:
        tuple: ``(mongo_vector, mock_collection, mock_db)``. The mocks are
        configured so the collection and its search index both appear to be
        missing, which makes the constructor go through the full creation path.
    """
    mock_client = mock_mongo_client.return_value
    mock_db = mock_client["test_db"]
    mock_collection = mock_db["test_collection"]
    mock_collection.list_search_indexes.return_value = []
    mock_collection.aggregate.return_value = []
    mock_collection.find_one.return_value = None
    mock_collection.find.return_value = []
    mock_db.list_collection_names.return_value = []

    # The constructor takes a connection URI; it has no separate
    # user/password parameters, so credentials go into the URI.
    mongo_vector = MongoVector(
        db_name="test_db",
        collection_name="test_collection",
        embedding_model_dims=1536,
        mongo_uri="mongodb://username:password@localhost:27017",
    )
    return mongo_vector, mock_collection, mock_db
|
||||
|
||||
def test_initalize_create_col(mongo_vector_fixture):
    """Constructing the store creates the collection and its search index."""
    mongo_vector, mock_collection, mock_db = mongo_vector_fixture
    assert mongo_vector.collection_name == "test_collection"
    assert mongo_vector.embedding_model_dims == 1536
    assert mongo_vector.db_name == "test_db"

    # Verify create_col being called
    mock_db.list_collection_names.assert_called_once()
    mock_collection.insert_one.assert_called_once_with({"_id": 0, "placeholder": True})
    mock_collection.delete_one.assert_called_once_with({"_id": 0})
    assert mongo_vector.index_name == "test_collection_vector_index"
    mock_collection.list_search_indexes.assert_called_once_with(name="test_collection_vector_index")
    mock_collection.create_search_index.assert_called_once()
    args, _ = mock_collection.create_search_index.call_args
    search_index_model = args[0].document
    assert search_index_model == {
        "name": "test_collection_vector_index",
        "definition": {
            "mappings": {
                "dynamic": False,
                "fields": {
                    "embedding": {
                        "type": "knnVector",
                        # Key must match what create_col() emits ("dimensions").
                        "dimensions": 1536,
                        "similarity": "cosine",
                    }
                },
            }
        },
    }
    assert mongo_vector.collection == mock_collection
|
||||
|
||||
def test_insert(mongo_vector_fixture):
    """insert() sends one document per (id, vector, payload) triple in a single insert_many call."""
    store, collection, _ = mongo_vector_fixture
    ids = ["id1", "id2"]
    vectors = [[0.1] * 1536, [0.2] * 1536]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]

    store.insert(vectors, payloads, ids)

    expected_records = [
        {"_id": doc_id, "embedding": vector, "payload": payload}
        for doc_id, vector, payload in zip(ids, vectors, payloads)
    ]
    collection.insert_many.assert_called_once_with(expected_records)
|
||||
|
||||
def test_search(mongo_vector_fixture):
    """search() runs the expected aggregation pipeline and maps the hits to OutputData."""
    store, collection, _ = mongo_vector_fixture
    index_name = "test_collection_vector_index"
    query_vector = [0.1] * 1536
    hits = [
        {"_id": "id1", "score": 0.9, "payload": {"key": "value1"}},
        {"_id": "id2", "score": 0.8, "payload": {"key": "value2"}},
    ]
    collection.aggregate.return_value = hits
    # A non-empty listing makes the store treat the index as present.
    collection.list_search_indexes.return_value = [index_name]

    results = store.search("query_str", query_vector, limit=2)

    collection.list_search_indexes.assert_called_with(name=index_name)
    expected_pipeline = [
        {
            "$vectorSearch": {
                "index": index_name,
                "limit": 2,
                "numCandidates": 2,
                "queryVector": query_vector,
                "path": "embedding",
            },
        },
        {"$set": {"score": {"$meta": "vectorSearchScore"}}},
        {"$project": {"embedding": 0}},
    ]
    collection.aggregate.assert_called_once_with(expected_pipeline)

    assert len(results) == 2
    first, second = results
    assert first.id == "id1"
    assert first.score == 0.9
    assert second.id == "id2"
    assert second.score == 0.8
|
||||
|
||||
def test_delete(mongo_vector_fixture):
    """delete() forwards the id to delete_one as an _id filter."""
    store, collection, _ = mongo_vector_fixture
    # Report one deleted document so the store takes its success branch.
    collection.delete_one.return_value = MagicMock(deleted_count=1)

    store.delete("id1")

    collection.delete_one.assert_called_with({"_id": "id1"})
|
||||
|
||||
def test_update(mongo_vector_fixture):
    """update() issues a single $set carrying both the new vector and the new payload."""
    store, collection, _ = mongo_vector_fixture
    # Report one matched document so the store takes its success branch.
    collection.update_one.return_value = MagicMock(matched_count=1)
    doc_id = "id1"
    new_vector = [0.2] * 1536
    new_payload = {"key": "updated"}

    store.update(doc_id, vector=new_vector, payload=new_payload)

    expected_filter = {"_id": doc_id}
    expected_change = {"$set": {"embedding": new_vector, "payload": new_payload}}
    collection.update_one.assert_called_once_with(expected_filter, expected_change)
|
||||
|
||||
def test_get(mongo_vector_fixture):
    """get() wraps the document returned by find_one in an OutputData."""
    store, collection, _ = mongo_vector_fixture
    stored_doc = {"_id": "id1", "payload": {"key": "value1"}}
    collection.find_one.return_value = stored_doc

    fetched = store.get("id1")

    assert fetched is not None
    assert fetched.id == "id1"
    assert fetched.payload == {"key": "value1"}
|
||||
|
||||
def test_list_cols(mongo_vector_fixture):
    """list_cols() returns whatever the database reports as its collection names."""
    store, _, database = mongo_vector_fixture
    expected_names = ["col1", "col2"]
    database.list_collection_names.return_value = expected_names

    assert store.list_cols() == ["col1", "col2"]
|
||||
|
||||
def test_delete_col(mongo_vector_fixture):
    """delete_col() drops the underlying collection exactly once."""
    store, collection, _ = mongo_vector_fixture

    store.delete_col()

    collection.drop.assert_called_once()
|
||||
|
||||
def test_col_info(mongo_vector_fixture):
    """col_info() queries collstats and reports name, count and size."""
    store, _, database = mongo_vector_fixture
    database.command.return_value = {"count": 10, "size": 1024}

    info = store.col_info()

    database.command.assert_called_once_with("collstats", "test_collection")
    expectations = (("name", "test_collection"), ("count", 10), ("size", 1024))
    for key, expected in expectations:
        assert info[key] == expected
|
||||
|
||||
def test_list(mongo_vector_fixture):
    """list() passes the filters to find(), applies the limit, and maps documents to OutputData."""
    store, collection, _ = mongo_vector_fixture
    documents = [
        {"_id": "id1", "payload": {"key": "value1"}},
        {"_id": "id2", "payload": {"key": "value2"}},
    ]
    # find() returns a cursor whose .limit() yields the documents.
    cursor = MagicMock()
    cursor.limit.return_value = documents
    collection.find.return_value = cursor
    query_filters = {"_id": {"$in": ["id1", "id2"]}}

    results = store.list(filters=query_filters, limit=2)

    collection.find.assert_called_once_with(query_filters)
    cursor.limit.assert_called_once_with(2)
    assert len(results) == 2
    first, second = results
    assert first.id == "id1"
    assert first.payload == {"key": "value1"}
    assert second.id == "id2"
    assert second.payload == {"key": "value2"}
|
||||
Reference in New Issue
Block a user