MongoDB Vector Store misaligned strings and classes (#3064)
This commit is contained in:
@@ -16,8 +16,7 @@ config = {
|
|||||||
"config": {
|
"config": {
|
||||||
"db_name": "mem0-db",
|
"db_name": "mem0-db",
|
||||||
"collection_name": "mem0-collection",
|
"collection_name": "mem0-collection",
|
||||||
"user": "my-user",
|
"mongo_uri":"mongodb://username:password@localhost:27017"
|
||||||
"password": "my-password",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -41,9 +40,6 @@ Here are the parameters available for configuring MongoDB:
|
|||||||
| db_name | Name of the MongoDB database | `"mem0_db"` |
|
| db_name | Name of the MongoDB database | `"mem0_db"` |
|
||||||
| collection_name | Name of the MongoDB collection | `"mem0_collection"` |
|
| collection_name | Name of the MongoDB collection | `"mem0_collection"` |
|
||||||
| embedding_model_dims | Dimensions of the embedding vectors | `1536` |
|
| embedding_model_dims | Dimensions of the embedding vectors | `1536` |
|
||||||
| user | MongoDB user for authentication | `None` |
|
| mongo_uri | The mongo URI connection string | mongodb://username:password@localhost:27017 |
|
||||||
| password | Password for the MongoDB user | `None` |
|
|
||||||
| host | MongoDB host | `"localhost"` |
|
|
||||||
| port | MongoDB port | `27017` |
|
|
||||||
|
|
||||||
> **Note**: `user` and `password` must either be provided together or omitted together.
|
> **Note**: If Mongo_uri is not provided it will default to mongodb://username:password@localhost:27017.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Dict, Optional, Callable, List
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
class MongoDBConfig(BaseModel):
|
class MongoDBConfig(BaseModel):
|
||||||
@@ -9,29 +9,12 @@ class MongoDBConfig(BaseModel):
|
|||||||
db_name: str = Field("mem0_db", description="Name of the MongoDB database")
|
db_name: str = Field("mem0_db", description="Name of the MongoDB database")
|
||||||
collection_name: str = Field("mem0", description="Name of the MongoDB collection")
|
collection_name: str = Field("mem0", description="Name of the MongoDB collection")
|
||||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
|
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding vectors")
|
||||||
user: Optional[str] = Field(None, description="MongoDB user for authentication")
|
mongo_uri: str = Field("mongodb://localhost:27017", description="MongoDB URI. Default is mongodb://localhost:27017")
|
||||||
password: Optional[str] = Field(None, description="Password for the MongoDB user")
|
|
||||||
host: Optional[str] = Field("localhost", description="MongoDB host. Default is 'localhost'")
|
|
||||||
port: Optional[int] = Field(27017, description="MongoDB port. Default is 27017")
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@model_validator(mode='before')
|
||||||
def check_auth_and_connection(cls, values):
|
@classmethod
|
||||||
user = values.get("user")
|
|
||||||
password = values.get("password")
|
|
||||||
if (user is None) != (password is None):
|
|
||||||
raise ValueError("Both 'user' and 'password' must be provided together or omitted together.")
|
|
||||||
|
|
||||||
host = values.get("host")
|
|
||||||
port = values.get("port")
|
|
||||||
if host is None:
|
|
||||||
raise ValueError("The 'host' must be provided.")
|
|
||||||
if port is None:
|
|
||||||
raise ValueError("The 'port' must be provided.")
|
|
||||||
return values
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
allowed_fields = set(cls.__fields__)
|
allowed_fields = set(cls.model_fields.keys())
|
||||||
input_fields = set(values.keys())
|
input_fields = set(values.keys())
|
||||||
extra_fields = input_fields - allowed_fields
|
extra_fields = input_fields - allowed_fields
|
||||||
if extra_fields:
|
if extra_fields:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class OutputData(BaseModel):
|
|||||||
payload: Optional[dict]
|
payload: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
class MongoVector(VectorStoreBase):
|
class MongoDB(VectorStoreBase):
|
||||||
VECTOR_TYPE = "knnVector"
|
VECTOR_TYPE = "knnVector"
|
||||||
SIMILARITY_METRIC = "cosine"
|
SIMILARITY_METRIC = "cosine"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import time
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from mem0.vector_stores.mongodb import MongoVector
|
from mem0.vector_stores.mongodb import MongoDB
|
||||||
from pymongo.operations import SearchIndexModel
|
from pymongo.operations import SearchIndexModel
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -16,12 +15,11 @@ def mongo_vector_fixture(mock_mongo_client):
|
|||||||
mock_collection.find.return_value = []
|
mock_collection.find.return_value = []
|
||||||
mock_db.list_collection_names.return_value = []
|
mock_db.list_collection_names.return_value = []
|
||||||
|
|
||||||
mongo_vector = MongoVector(
|
mongo_vector = MongoDB(
|
||||||
db_name="test_db",
|
db_name="test_db",
|
||||||
collection_name="test_collection",
|
collection_name="test_collection",
|
||||||
embedding_model_dims=1536,
|
embedding_model_dims=1536,
|
||||||
user="username",
|
mongo_uri="mongodb://username:password@localhost:27017"
|
||||||
password="password",
|
|
||||||
)
|
)
|
||||||
return mongo_vector, mock_collection, mock_db
|
return mongo_vector, mock_collection, mock_db
|
||||||
|
|
||||||
@@ -48,7 +46,7 @@ def test_initalize_create_col(mongo_vector_fixture):
|
|||||||
"fields": {
|
"fields": {
|
||||||
"embedding": {
|
"embedding": {
|
||||||
"type": "knnVector",
|
"type": "knnVector",
|
||||||
"d": 1536,
|
"dimensions": 1536,
|
||||||
"similarity": "cosine",
|
"similarity": "cosine",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user