MongoDB Vector Store misaligned strings and classes (#3064)

This commit is contained in:
Kade Shockey
2025-06-30 23:42:28 -06:00
committed by GitHub
parent 5a1083b709
commit b79bfb7c1e
4 changed files with 14 additions and 37 deletions

View File

@@ -16,8 +16,7 @@ config = {
"config": { "config": {
"db_name": "mem0-db", "db_name": "mem0-db",
"collection_name": "mem0-collection", "collection_name": "mem0-collection",
"user": "my-user", "mongo_uri":"mongodb://username:password@localhost:27017"
"password": "my-password",
} }
} }
} }
@@ -41,9 +40,6 @@ Here are the parameters available for configuring MongoDB:
| db_name | Name of the MongoDB database | `"mem0_db"` | | db_name | Name of the MongoDB database | `"mem0_db"` |
| collection_name | Name of the MongoDB collection | `"mem0_collection"` | | collection_name | Name of the MongoDB collection | `"mem0_collection"` |
| embedding_model_dims | Dimensions of the embedding vectors | `1536` | | embedding_model_dims | Dimensions of the embedding vectors | `1536` |
| user | MongoDB user for authentication | `None` | | mongo_uri | The mongo URI connection string | mongodb://username:password@localhost:27017 |
| password | Password for the MongoDB user | `None` |
| host | MongoDB host | `"localhost"` |
| port | MongoDB port | `27017` |
> **Note**: `user` and `password` must either be provided together or omitted together. > **Note**: If Mongo_uri is not provided it will default to mongodb://username:password@localhost:27017.

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional, Callable, List from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, model_validator
class MongoDBConfig(BaseModel): class MongoDBConfig(BaseModel):
@@ -9,29 +9,12 @@ class MongoDBConfig(BaseModel):
db_name: str = Field("mem0_db", description="Name of the MongoDB database") db_name: str = Field("mem0_db", description="Name of the MongoDB database")
collection_name: str = Field("mem0", description="Name of the MongoDB collection") collection_name: str = Field("mem0", description="Name of the MongoDB collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors") embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
user: Optional[str] = Field(None, description="MongoDB user for authentication") mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
password: Optional[str] = Field(None, description="Password for the MongoDB user")
host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'")
port: Optional[int] = Field(27017, description="MongoDB port. Default is 27017")
@root_validator(pre=True) @model_validator(mode='before')
def check_auth_and_connection(cls, values): @classmethod
user = values.get("user")
password = values.get("password")
if (user is None) != (password is None):
raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")
host = values.get("host")
port = values.get("port")
if host is None:
raise ValueError("The 'host' must be provided.")
if port is None:
raise ValueError("The 'port' must be provided.")
return values
@root_validator(pre=True)
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.__fields__) allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys()) input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields extra_fields = input_fields - allowed_fields
if extra_fields: if extra_fields:

View File

@@ -22,7 +22,7 @@ class OutputData(BaseModel):
payload: Optional[dict] payload: Optional[dict]
class MongoVector(VectorStoreBase): class MongoDB(VectorStoreBase):
VECTOR_TYPE = "knnVector" VECTOR_TYPE = "knnVector"
SIMILARITY_METRIC = "cosine" SIMILARITY_METRIC = "cosine"

View File

@@ -1,7 +1,6 @@
import time
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from mem0.vector_stores.mongodb import MongoVector from mem0.vector_stores.mongodb import MongoDB
from pymongo.operations import SearchIndexModel from pymongo.operations import SearchIndexModel
@pytest.fixture @pytest.fixture
@@ -16,12 +15,11 @@ def mongo_vector_fixture(mock_mongo_client):
mock_collection.find.return_value = [] mock_collection.find.return_value = []
mock_db.list_collection_names.return_value = [] mock_db.list_collection_names.return_value = []
mongo_vector = MongoVector( mongo_vector = MongoDB(
db_name="test_db", db_name="test_db",
collection_name="test_collection", collection_name="test_collection",
embedding_model_dims=1536, embedding_model_dims=1536,
user="username", mongo_uri="mongodb://username:password@localhost:27017"
password="password",
) )
return mongo_vector, mock_collection, mock_db return mongo_vector, mock_collection, mock_db
@@ -48,7 +46,7 @@ def test_initalize_create_col(mongo_vector_fixture):
"fields": { "fields": {
"embedding": { "embedding": {
"type": "knnVector", "type": "knnVector",
"d": 1536, "dimensions": 1536,
"similarity": "cosine", "similarity": "cosine",
} }
} }