diff --git a/Makefile b/Makefile index 59c745ca..a01c84a5 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ install: install_all: pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \ google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \ - upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 + upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow # Format code with ruff format: diff --git a/docs/components/vectordbs/dbs/baidu.mdx b/docs/components/vectordbs/dbs/baidu.mdx new file mode 100644 index 00000000..457fff2b --- /dev/null +++ b/docs/components/vectordbs/dbs/baidu.mdx @@ -0,0 +1,67 @@ +--- +title: Baidu VectorDB (Mochow) +--- + +[Baidu VectorDB](https://cloud.baidu.com/doc/VDB/index.html) is an enterprise-level distributed vector database service developed by Baidu Intelligent Cloud. It is powered by Baidu's proprietary "Mochow" vector database kernel, providing high performance, availability, and security for vector search. + +### Usage + +```python +import os +from mem0 import Memory + +config = { + "vector_store": { + "provider": "baidu", + "config": { + "endpoint": "http://your-mochow-endpoint:8287", + "account": "root", + "api_key": "your-api-key", + "database_name": "mem0", + "table_name": "mem0_table", + "embedding_model_dims": 1536, + "metric_type": "COSINE" + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! 
I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + +### Config + +Here are the available parameters for the `baidu` config: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `endpoint` | Endpoint URL for your Baidu VectorDB instance | Required | +| `account` | Baidu VectorDB account name | `root` | +| `api_key` | API key for accessing Baidu VectorDB | Required | +| `database_name` | Name of the database | `mem0` | +| `table_name` | Name of the table | `mem0` | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | +| `metric_type` | Distance metric for similarity search | `L2` | + +### Distance Metrics + +The following distance metrics are supported: + +- `L2`: Euclidean distance (default) +- `IP`: Inner product +- `COSINE`: Cosine similarity + +### Index Configuration + +The vector index is automatically configured with the following HNSW parameters: + +- `m`: 16 (number of connections per element) +- `efconstruction`: 200 (size of the dynamic candidate list) +- `auto_build`: true (automatically build index) +- `auto_build_index_policy`: Incremental build with 10000 rows increment diff --git a/docs/docs.json b/docs/docs.json index ead405a0..47dd1add 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -146,7 +146,8 @@ "components/vectordbs/dbs/vertex_ai", "components/vectordbs/dbs/weaviate", "components/vectordbs/dbs/faiss", - "components/vectordbs/dbs/langchain" + "components/vectordbs/dbs/langchain", + "components/vectordbs/dbs/baidu" ] } ] diff --git a/mem0/configs/vector_stores/baidu.py b/mem0/configs/vector_stores/baidu.py new file mode 100644 index 00000000..7a1ca6a4 --- /dev/null +++ b/mem0/configs/vector_stores/baidu.py @@ -0,0 +1,30 @@ +from enum import Enum +from typing import Any, Dict + +from pydantic import BaseModel, Field, model_validator + + +class BaiduDBConfig(BaseModel): + endpoint: str = 
Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB") + account: str = Field("root", description="Account for Baidu VectorDB") + api_key: str = Field(None, description="API Key for Baidu VectorDB") + database_name: str = Field("mem0", description="Name of the database") + table_name: str = Field("mem0", description="Name of the table") + embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model") + metric_type: str = Field("L2", description="Metric type for similarity search") + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + ) + return values + + model_config = { + "arbitrary_types_allowed": True, + } diff --git a/mem0/vector_stores/baidu.py b/mem0/vector_stores/baidu.py new file mode 100644 index 00000000..0a3ed139 --- /dev/null +++ b/mem0/vector_stores/baidu.py @@ -0,0 +1,349 @@ +import logging +import time +from typing import Dict, Optional + +from pydantic import BaseModel + +from mem0.vector_stores.base import VectorStoreBase + +try: + import pymochow + from pymochow.configuration import Configuration + from pymochow.auth.bce_credentials import BceCredentials + from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode + from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement + from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector + from pymochow.exception import ServerError +except ImportError: + raise ImportError("The 'pymochow' library is required. 
Please install it using 'pip install pymochow'.") + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: Optional[str] # memory id + score: Optional[float] # distance + payload: Optional[Dict] # metadata + + +class BaiduDB(VectorStoreBase): + def __init__( + self, + endpoint: str, + account: str, + api_key: str, + database_name: str, + table_name: str, + embedding_model_dims: int, + metric_type: MetricType, + ) -> None: + """Initialize the BaiduDB database. + + Args: + endpoint (str): Endpoint URL for Baidu VectorDB. + account (str): Account for Baidu VectorDB. + api_key (str): API Key for Baidu VectorDB. + database_name (str): Name of the database. + table_name (str): Name of the table. + embedding_model_dims (int): Dimensions of the embedding model. + metric_type (MetricType): Metric type for similarity search. + """ + self.endpoint = endpoint + self.account = account + self.api_key = api_key + self.database_name = database_name + self.table_name = table_name + self.embedding_model_dims = embedding_model_dims + self.metric_type = metric_type + + # Initialize Mochow client + config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint) + self.client = pymochow.MochowClient(config) + + # Ensure database and table exist + self._create_database_if_not_exists() + self.create_col( + name=self.table_name, + vector_size=self.embedding_model_dims, + distance=self.metric_type, + ) + + def _create_database_if_not_exists(self): + """Create database if it doesn't exist.""" + try: + # Check if database exists + databases = self.client.list_databases() + db_exists = any(db.database_name == self.database_name for db in databases) + if not db_exists: + self._database = self.client.create_database(self.database_name) + logger.info(f"Created database: {self.database_name}") + else: + self._database = self.client.database(self.database_name) + logger.info(f"Database {self.database_name} already exists") + except Exception as e: + 
logger.error(f"Error creating database: {e}") + raise + + def create_col(self, name, vector_size, distance): + """Create a new table. + + Args: + name (str): Name of the table to create. + vector_size (int): Dimension of the vector. + distance (str): Metric type for similarity search. + """ + # Check if table already exists + try: + tables = self._database.list_table() + table_exists = any(table.table_name == name for table in tables) + if table_exists: + logger.info(f"Table {name} already exists. Skipping creation.") + self._table = self._database.describe_table(name) + return + + # Convert distance string to MetricType enum + metric_type = None + for k, v in MetricType.__members__.items(): + if k == distance: + metric_type = v + if metric_type is None: + raise ValueError(f"Unsupported metric_type: {distance}") + + # Define table schema + fields = [ + Field( + "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True + ), + Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size), + Field("metadata", FieldType.JSON), + ] + + # Create vector index + indexes = [ + VectorIndex( + index_name="vector_idx", + index_type=IndexType.HNSW, + field="vector", + metric_type=metric_type, + params=HNSWParams(m=16, efconstruction=200), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000), + ), + FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]), + ] + + schema = Schema(fields=fields, indexes=indexes) + + # Create table + self._table = self._database.create_table( + table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema + ) + logger.info(f"Created table: {name}") + + # Wait for table to be ready + while True: + time.sleep(2) + table = self._database.describe_table(name) + if table.state == TableState.NORMAL: + logger.info(f"Table {name} is ready.") + break + logger.info(f"Waiting for table {name} to be ready, current state: {table.state}") 
+ self._table = table + except Exception as e: + logger.error(f"Error creating table: {e}") + raise + + def insert(self, vectors, payloads=None, ids=None): + """Insert vectors into the table. + + Args: + vectors (List[List[float]]): List of vectors to insert. + payloads (List[Dict], optional): List of payloads corresponding to vectors. + ids (List[str], optional): List of IDs corresponding to vectors. + """ + # Prepare data for insertion + for idx, vector, metadata in zip(ids, vectors, payloads): + row = Row(id=idx, vector=vector, metadata=metadata) + self._table.upsert(rows=[row]) + + def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list: + """ + Search for similar vectors. + + Args: + query (str): Query string. + vectors (List[float]): Query vector. + limit (int, optional): Number of results to return. Defaults to 5. + filters (Dict, optional): Filters to apply to the search. Defaults to None. + + Returns: + list: Search results. + """ + # Add filters if provided + search_filter = None + if filters: + search_filter = self._create_filter(filters) + + # Create AnnSearch for vector search + request = VectorTopkSearchRequest( + vector_field="vector", + vector=FloatVector(vectors), + limit=limit, + filter=search_filter, + config=VectorSearchConfig(ef=200), + ) + + # Perform search + projections = ["id", "metadata"] + res = self._table.vector_search(request=request, projections=projections) + + # Parse results + output = [] + for row in res.rows: + row_data = row.get("row", {}) + output_data = OutputData( + id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {}) + ) + output.append(output_data) + + return output + + def delete(self, vector_id): + """ + Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. + """ + self._table.delete(primary_key={"id": vector_id}) + + def update(self, vector_id=None, vector=None, payload=None): + """ + Update a vector and its payload. 
+ + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + row = Row(id=vector_id, vector=vector, metadata=payload) + self._table.upsert(rows=[row]) + + def get(self, vector_id): + """ + Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + OutputData: Retrieved vector. + """ + projections = ["id", "metadata"] + result = self._table.query(primary_key={"id": vector_id}, projections=projections) + row = result.row + return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {})) + + def list_cols(self): + """ + List all tables (collections). + + Returns: + List[str]: List of table names. + """ + tables = self._database.list_table() + return [table.table_name for table in tables] + + def delete_col(self): + """Delete the table.""" + try: + tables = self._database.list_table() + + # skip drop table if table not exists + table_exists = any(table.table_name == self.table_name for table in tables) + if not table_exists: + logger.info(f"Table {self.table_name} does not exist, skipping deletion") + return + + # Delete the table + self._database.drop_table(self.table_name) + logger.info(f"Initiated deletion of table {self.table_name}") + + # Wait for table to be completely deleted + while True: + time.sleep(2) + try: + self._database.describe_table(self.table_name) + logger.info(f"Waiting for table {self.table_name} to be deleted...") + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + logger.info(f"Table {self.table_name} has been completely deleted") + break + logger.error(f"Error checking table status: {e}") + raise + except Exception as e: + logger.error(f"Error deleting table: {e}") + raise + + def col_info(self): + """ + Get information about the table. + + Returns: + Dict[str, Any]: Table information. 
+ """ + return self._table.stats() + + def list(self, filters: dict = None, limit: int = 100) -> list: + """ + List all vectors in the table. + + Args: + filters (Dict, optional): Filters to apply to the list. + limit (int, optional): Number of vectors to return. Defaults to 100. + + Returns: + List[OutputData]: List of vectors. + """ + projections = ["id", "metadata"] + list_filter = self._create_filter(filters) if filters else None + result = self._table.select(filter=list_filter, projections=projections, limit=limit) + + memories = [] + for row in result.rows: + obj = OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {})) + memories.append(obj) + + return [memories] + + def reset(self): + """Reset the table by deleting and recreating it.""" + logger.warning(f"Resetting table {self.table_name}...") + try: + self.delete_col() + self.create_col( + name=self.table_name, + vector_size=self.embedding_model_dims, + distance=self.metric_type, + ) + except Exception as e: + logger.warning(f"Error resetting table: {e}") + raise + + def _create_filter(self, filters: dict) -> str: + """ + Create filter expression for queries. + + Args: + filters (dict): Filter conditions. + + Returns: + str: Filter expression. 
+ """ + conditions = [] + for key, value in filters.items(): + if isinstance(value, str): + conditions.append(f'metadata["{key}"] = "{value}"') + else: + conditions.append(f'metadata["{key}"] = {value}') + return " AND ".join(conditions) diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index e360d238..7e03a6f8 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -17,6 +17,7 @@ class VectorStoreConfig(BaseModel): "pinecone": "PineconeConfig", "mongodb": "MongoDBConfig", "milvus": "MilvusDBConfig", + "baidu": "BaiduDBConfig", "upstash_vector": "UpstashVectorConfig", "azure_ai_search": "AzureAISearchConfig", "redis": "RedisDBConfig", diff --git a/tests/vector_stores/test_baidu.py b/tests/vector_stores/test_baidu.py new file mode 100644 index 00000000..c5ef5734 --- /dev/null +++ b/tests/vector_stores/test_baidu.py @@ -0,0 +1,233 @@ +from unittest.mock import Mock, patch, PropertyMock + +import pytest + +from mem0.vector_stores.baidu import BaiduDB, OutputData +from pymochow.model.enum import MetricType, TableState, ServerErrCode +from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement +from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table +from pymochow.exception import ServerError + + +@pytest.fixture +def mock_mochow_client(): + with patch("pymochow.MochowClient") as mock_client: + yield mock_client + + +@pytest.fixture +def mock_configuration(): + with patch("pymochow.configuration.Configuration") as mock_config: + yield mock_config + + +@pytest.fixture +def mock_bce_credentials(): + with patch("pymochow.auth.bce_credentials.BceCredentials") as mock_creds: + yield mock_creds + + +@pytest.fixture +def mock_table(): + mock_table = Mock(spec=Table) + # 设置 Table 类的属性 + type(mock_table).database_name = PropertyMock(return_value="test_db") + type(mock_table).table_name = 
PropertyMock(return_value="test_table") + type(mock_table).schema = PropertyMock(return_value=Mock()) + type(mock_table).replication = PropertyMock(return_value=1) + type(mock_table).partition = PropertyMock(return_value=Mock()) + type(mock_table).enable_dynamic_field = PropertyMock(return_value=False) + type(mock_table).description = PropertyMock(return_value="") + type(mock_table).create_time = PropertyMock(return_value="") + type(mock_table).state = PropertyMock(return_value=TableState.NORMAL) + type(mock_table).aliases = PropertyMock(return_value=[]) + return mock_table + + +@pytest.fixture +def mochow_instance(mock_mochow_client, mock_configuration, mock_bce_credentials, mock_table): + mock_database = Mock() + mock_client_instance = Mock() + + # Mock the client creation + mock_mochow_client.return_value = mock_client_instance + + # Mock database operations + mock_client_instance.list_databases.return_value = [] + mock_client_instance.create_database.return_value = mock_database + mock_client_instance.database.return_value = mock_database + + # Mock table operations + mock_database.list_table.return_value = [] + mock_database.create_table.return_value = mock_table + mock_database.describe_table.return_value = Mock(state=TableState.NORMAL) + mock_database.table.return_value = mock_table + + return BaiduDB( + endpoint="http://localhost:8287", + account="test_account", + api_key="test_api_key", + database_name="test_db", + table_name="test_table", + embedding_model_dims=128, + metric_type="COSINE", + ) + + +def test_insert(mochow_instance, mock_mochow_client): + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + payloads = [{"name": "vector1"}, {"name": "vector2"}] + ids = ["id1", "id2"] + + mochow_instance.insert(vectors=vectors, payloads=payloads, ids=ids) + + # Verify table.upsert was called with correct data + assert mochow_instance._table.upsert.call_count == 2 + calls = mochow_instance._table.upsert.call_args_list + + # Check first call + first_row = 
calls[0][1]["rows"][0] + assert first_row._data["id"] == "id1" + assert first_row._data["vector"] == [0.1, 0.2, 0.3] + assert first_row._data["metadata"] == {"name": "vector1"} + + # Check second call + second_row = calls[1][1]["rows"][0] + assert second_row._data["id"] == "id2" + assert second_row._data["vector"] == [0.4, 0.5, 0.6] + assert second_row._data["metadata"] == {"name": "vector2"} + + +def test_search(mochow_instance, mock_mochow_client): + # Mock search results + mock_search_results = Mock() + mock_search_results.rows = [ + {"row": {"id": "id1", "metadata": {"name": "vector1"}}, "score": 0.1}, + {"row": {"id": "id2", "metadata": {"name": "vector2"}}, "score": 0.2}, + ] + mochow_instance._table.vector_search.return_value = mock_search_results + + vectors = [0.1, 0.2, 0.3] + results = mochow_instance.search(query="test", vectors=vectors, limit=2) + + # Verify search was called with correct parameters + mochow_instance._table.vector_search.assert_called_once() + call_args = mochow_instance._table.vector_search.call_args + request = call_args[0][0] if call_args[0] else call_args[1]["request"] + + assert isinstance(request, VectorTopkSearchRequest) + assert request._vector_field == "vector" + assert isinstance(request._vector, FloatVector) + assert request._vector._floats == vectors + assert request._limit == 2 + assert isinstance(request._config, VectorSearchConfig) + assert request._config._ef == 200 + + # Verify results + assert len(results) == 2 + assert results[0].id == "id1" + assert results[0].score == 0.1 + assert results[0].payload == {"name": "vector1"} + assert results[1].id == "id2" + assert results[1].score == 0.2 + assert results[1].payload == {"name": "vector2"} + + +def test_search_with_filters(mochow_instance, mock_mochow_client): + mochow_instance._table.vector_search.return_value = Mock(rows=[]) + + vectors = [0.1, 0.2, 0.3] + filters = {"user_id": "user123", "agent_id": "agent456"} + + mochow_instance.search(query="test", 
vectors=vectors, limit=2, filters=filters) + + # Verify search was called with filter + call_args = mochow_instance._table.vector_search.call_args + request = call_args[0][0] if call_args[0] else call_args[1]["request"] + + assert request._filter == 'metadata["user_id"] = "user123" AND metadata["agent_id"] = "agent456"' + + +def test_delete(mochow_instance, mock_mochow_client): + vector_id = "id1" + mochow_instance.delete(vector_id=vector_id) + + mochow_instance._table.delete.assert_called_once_with(primary_key={"id": vector_id}) + + +def test_update(mochow_instance, mock_mochow_client): + vector_id = "id1" + new_vector = [0.7, 0.8, 0.9] + new_payload = {"name": "updated_vector"} + + mochow_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload) + + mochow_instance._table.upsert.assert_called_once() + call_args = mochow_instance._table.upsert.call_args + row = call_args[0][0] if call_args[0] else call_args[1]["rows"][0] + + assert row._data["id"] == vector_id + assert row._data["vector"] == new_vector + assert row._data["metadata"] == new_payload + + +def test_get(mochow_instance, mock_mochow_client): + # Mock query result + mock_result = Mock() + mock_result.row = {"id": "id1", "metadata": {"name": "vector1"}} + mochow_instance._table.query.return_value = mock_result + + result = mochow_instance.get(vector_id="id1") + + mochow_instance._table.query.assert_called_once_with(primary_key={"id": "id1"}, projections=["id", "metadata"]) + + assert result.id == "id1" + assert result.score is None + assert result.payload == {"name": "vector1"} + + +def test_list(mochow_instance, mock_mochow_client): + # Mock select result + mock_result = Mock() + mock_result.rows = [{"id": "id1", "metadata": {"name": "vector1"}}, {"id": "id2", "metadata": {"name": "vector2"}}] + mochow_instance._table.select.return_value = mock_result + + results = mochow_instance.list(limit=2) + + mochow_instance._table.select.assert_called_once_with(filter=None, projections=["id", 
"metadata"], limit=2) + + assert len(results[0]) == 2 + assert results[0][0].id == "id1" + assert results[0][1].id == "id2" + + +def test_list_cols(mochow_instance, mock_mochow_client): + # Mock table list + mock_tables = [ + Mock(spec=Table, database_name="test_db", table_name="table1"), + Mock(spec=Table, database_name="test_db", table_name="table2"), + ] + mochow_instance._database.list_table.return_value = mock_tables + + result = mochow_instance.list_cols() + + assert result == ["table1", "table2"] + + +def test_delete_col_not_exists(mochow_instance, mock_mochow_client): + # 使用正确的 ServerErrCode 枚举值 + mochow_instance._database.drop_table.side_effect = ServerError( + "Table not exists", code=ServerErrCode.TABLE_NOT_EXIST + ) + + # Should not raise exception + mochow_instance.delete_col() + + +def test_col_info(mochow_instance, mock_mochow_client): + mock_table_info = {"table_name": "test_table", "fields": []} + mochow_instance._table.stats.return_value = mock_table_info + + result = mochow_instance.col_info() + + assert result == mock_table_info