MongoDB Vector Store misaligned strings and classes (#3064)
This commit is contained in:
@@ -16,8 +16,7 @@ config = {
|
||||
"config": {
|
||||
"db_name": "mem0-db",
|
||||
"collection_name": "mem0-collection",
|
||||
"user": "my-user",
|
||||
"password": "my-password",
|
||||
"mongo_uri":"mongodb://username:password@localhost:27017"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -41,9 +40,6 @@ Here are the parameters available for configuring MongoDB:
|
||||
| db_name | Name of the MongoDB database | `"mem0_db"` |
|
||||
| collection_name | Name of the MongoDB collection | `"mem0_collection"` |
|
||||
| embedding_model_dims | Dimensions of the embedding vectors | `1536` |
|
||||
| user | MongoDB user for authentication | `None` |
|
||||
| password | Password for the MongoDB user | `None` |
|
||||
| host | MongoDB host | `"localhost"` |
|
||||
| port | MongoDB port | `27017` |
|
||||
| mongo_uri | The mongo URI connection string | mongodb://username:password@localhost:27017 |
|
||||
|
||||
> **Note**: `user` and `password` must either be provided together or omitted together.
|
||||
> **Note**: If Mongo_uri is not provided it will default to mongodb://username:password@localhost:27017.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict, Optional, Callable, List
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MongoDBConfig(BaseModel):
|
||||
@@ -9,29 +9,12 @@ class MongoDBConfig(BaseModel):
|
||||
db_name: str = Field("mem0_db", description="Name of the MongoDB database")
|
||||
collection_name: str = Field("mem0", description="Name of the MongoDB collection")
|
||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
|
||||
user: Optional[str] = Field(None, description="MongoDB user for authentication")
|
||||
password: Optional[str] = Field(None, description="Password for the MongoDB user")
|
||||
host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'")
|
||||
port: Optional[int] = Field(27017, description="MongoDB port. Default is 27017")
|
||||
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_auth_and_connection(cls, values):
|
||||
user = values.get("user")
|
||||
password = values.get("password")
|
||||
if (user is None) != (password is None):
|
||||
raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")
|
||||
|
||||
host = values.get("host")
|
||||
port = values.get("port")
|
||||
if host is None:
|
||||
raise ValueError("The 'host' must be provided.")
|
||||
if port is None:
|
||||
raise ValueError("The 'port' must be provided.")
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
allowed_fields = set(cls.__fields__)
|
||||
allowed_fields = set(cls.model_fields.keys())
|
||||
input_fields = set(values.keys())
|
||||
extra_fields = input_fields - allowed_fields
|
||||
if extra_fields:
|
||||
|
||||
@@ -22,7 +22,7 @@ class OutputData(BaseModel):
|
||||
payload: Optional[dict]
|
||||
|
||||
|
||||
class MongoVector(VectorStoreBase):
|
||||
class MongoDB(VectorStoreBase):
|
||||
VECTOR_TYPE = "knnVector"
|
||||
SIMILARITY_METRIC = "cosine"
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from mem0.vector_stores.mongodb import MongoVector
|
||||
from mem0.vector_stores.mongodb import MongoDB
|
||||
from pymongo.operations import SearchIndexModel
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,12 +15,11 @@ def mongo_vector_fixture(mock_mongo_client):
|
||||
mock_collection.find.return_value = []
|
||||
mock_db.list_collection_names.return_value = []
|
||||
|
||||
mongo_vector = MongoVector(
|
||||
mongo_vector = MongoDB(
|
||||
db_name="test_db",
|
||||
collection_name="test_collection",
|
||||
embedding_model_dims=1536,
|
||||
user="username",
|
||||
password="password",
|
||||
mongo_uri="mongodb://username:password@localhost:27017"
|
||||
)
|
||||
return mongo_vector, mock_collection, mock_db
|
||||
|
||||
@@ -48,7 +46,7 @@ def test_initalize_create_col(mongo_vector_fixture):
|
||||
"fields": {
|
||||
"embedding": {
|
||||
"type": "knnVector",
|
||||
"d": 1536,
|
||||
"dimensions": 1536,
|
||||
"similarity": "cosine",
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user