Feature: baidu vector db integration (#2929)

This commit is contained in:
Shili Cao
2025-06-19 13:42:12 +08:00
committed by GitHub
parent cdee6a4ff0
commit d35065c887
7 changed files with 683 additions and 2 deletions

View File

@@ -13,7 +13,7 @@ install:
install_all:
pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25
upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow
# Format code with ruff
format:

View File

@@ -0,0 +1,67 @@
---
title: Baidu VectorDB (Mochow)
---
[Baidu VectorDB](https://cloud.baidu.com/doc/VDB/index.html) is an enterprise-level distributed vector database service developed by Baidu Intelligent Cloud. It is powered by Baidu's proprietary "Mochow" vector database kernel, providing high performance, availability, and security for vector search.
### Usage
```python
import os
from mem0 import Memory
config = {
"vector_store": {
"provider": "baidu",
"config": {
"endpoint": "http://your-mochow-endpoint:8287",
"account": "root",
"api_key": "your-api-key",
"database_name": "mem0",
"table_name": "mem0_table",
"embedding_model_dims": 1536,
"metric_type": "COSINE"
}
}
}
m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
### Config
Here are the available parameters for the `baidu` vector store config:
| Parameter | Description | Default Value |
| --- | --- | --- |
| `endpoint` | Endpoint URL for your Baidu VectorDB instance | `http://localhost:8287` |
| `account` | Baidu VectorDB account name | `root` |
| `api_key` | API key for accessing Baidu VectorDB | Required |
| `database_name` | Name of the database | `mem0` |
| `table_name` | Name of the table | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `metric_type` | Distance metric for similarity search | `L2` |
### Distance Metrics
The following distance metrics are supported:
- `L2`: Euclidean distance (default)
- `IP`: Inner product
- `COSINE`: Cosine similarity
### Index Configuration
The vector index is automatically configured with the following HNSW parameters:
- `m`: 16 (number of connections per element)
- `efconstruction`: 200 (size of the dynamic candidate list)
- `auto_build`: true (automatically build index)
- `auto_build_index_policy`: Incremental build with 10000 rows increment

View File

@@ -146,7 +146,8 @@
"components/vectordbs/dbs/vertex_ai",
"components/vectordbs/dbs/weaviate",
"components/vectordbs/dbs/faiss",
"components/vectordbs/dbs/langchain"
"components/vectordbs/dbs/langchain",
"components/vectordbs/dbs/baidu"
]
}
]

View File

@@ -0,0 +1,30 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator
class BaiduDBConfig(BaseModel):
    """Settings accepted by the Baidu VectorDB vector store.

    Unknown keys are rejected before field validation so that a misspelled
    option fails loudly instead of being ignored.
    """

    endpoint: str = Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB")
    account: str = Field("root", description="Account for Baidu VectorDB")
    # The default is None, so the annotation has to admit None as well.
    api_key: Optional[str] = Field(None, description="API Key for Baidu VectorDB")
    database_name: str = Field("mem0", description="Name of the database")
    table_name: str = Field("mem0", description="Name of the table")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input key that is not a declared field.

        Args:
            values: Raw input handed to the model before field validation.

        Returns:
            The input, unchanged, when it contains only declared fields.

        Raises:
            ValueError: If the input contains keys that are not model fields.
        """
        # A "before" validator receives the raw input; only mappings carry
        # keys that can be checked here, anything else is left to pydantic.
        if not isinstance(values, dict):
            return values
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            # Sorted so the message is stable across runs (set order is not).
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. "
                f"Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }

349
mem0/vector_stores/baidu.py Normal file
View File

@@ -0,0 +1,349 @@
import logging
import time
from typing import Dict, Optional
from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase
try:
import pymochow
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector
from pymochow.exception import ServerError
except ImportError:
raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """One record handed back by the store: its id, its score and its payload."""

    # memory id
    id: Optional[str]
    # distance
    score: Optional[float]
    # metadata
    payload: Optional[Dict]
class BaiduDB(VectorStoreBase):
    """Vector store backed by Baidu VectorDB (Mochow) through the ``pymochow`` client.

    An instance is bound to one database and one table. Both are created
    while the instance is constructed if they do not exist yet.
    """

    # Seconds to sleep between two polls of the table state.
    _POLL_INTERVAL_SECONDS = 2
    # Upper bound, in seconds, on waiting for a table to become ready or to
    # disappear; prevents a stuck table from hanging the caller forever.
    _TABLE_WAIT_TIMEOUT_SECONDS = 600

    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the BaiduDB database.

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search, given
                either as a ``MetricType`` member or as its name (e.g. "COSINE").
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure database and table exist
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Bind ``self._database`` to the configured database, creating it if absent."""
        try:
            # Check if database exists
            db_exists = any(db.database_name == self.database_name for db in self.client.list_databases())
            if db_exists:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
            else:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    def create_col(self, name, vector_size, distance):
        """Create a new table, or attach to it when it already exists.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (str | MetricType): Metric type for similarity search.

        Raises:
            ValueError: If ``distance`` is not a known metric type.
            TimeoutError: If the new table does not become ready in time.
        """
        try:
            # Check if table already exists
            if any(table.table_name == name for table in self._database.list_table()):
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            metric_type = self._resolve_metric_type(distance)

            # Define table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # Create vector index
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]
            schema = Schema(fields=fields, indexes=indexes)

            # Create table
            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # Wait for table to be ready
            self._table = self._wait_for_table_ready(name)
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    @staticmethod
    def _resolve_metric_type(distance):
        """Map a metric name or ``MetricType`` member to a ``MetricType`` member."""
        # The constructor is annotated with MetricType, so an enum member has
        # to be accepted as well as the plain name coming from the config.
        if isinstance(distance, MetricType):
            return distance
        metric_type = MetricType.__members__.get(distance) if isinstance(distance, str) else None
        if metric_type is None:
            raise ValueError(f"Unsupported metric_type: {distance}")
        return metric_type

    def _wait_for_table_ready(self, name):
        """Poll until table ``name`` reports ``TableState.NORMAL`` and return it."""
        deadline = time.monotonic() + self._TABLE_WAIT_TIMEOUT_SECONDS
        while True:
            # Check first, sleep afterwards: a table that is already ready
            # costs no waiting time.
            table = self._database.describe_table(name)
            if table.state == TableState.NORMAL:
                logger.info(f"Table {name} is ready.")
                return table
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Table {name} not ready after {self._TABLE_WAIT_TIMEOUT_SECONDS}s, current state: {table.state}"
                )
            logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            time.sleep(self._POLL_INTERVAL_SECONDS)

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the table.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to
                vectors. Defaults to an empty payload per vector.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Defaults to a freshly generated UUID per vector.
        """
        import uuid  # local import: only needed when the caller supplies no ids

        vectors = list(vectors)
        # Both arguments are declared optional; zipping over None would raise.
        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        # Prepare data for insertion
        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(self, query: str, vectors: list, limit: int = 5, filters: Optional[dict] = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query string (not used by this implementation).
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results as ``OutputData`` objects.
        """
        # Add filters if provided
        search_filter = self._create_filter(filters) if filters else None

        # Create AnnSearch for vector search
        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        # Perform search
        res = self._table.vector_search(request=request, projections=["id", "metadata"])

        # Parse results
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output.append(
                OutputData(id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {}))
            )
        return output

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload.

        A field left as ``None`` keeps its stored value instead of being
        overwritten with ``None``.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        if vector is None and payload is None:
            # Nothing to change.
            return
        if vector is None or payload is None:
            # upsert replaces the whole row, so carry over the stored value of
            # whichever field the caller did not supply.
            # NOTE(review): assumes query() can project the "vector" field and
            # yields a falsy row when the id is unknown -- confirm against pymochow.
            current = self._table.query(primary_key={"id": vector_id}, projections=["vector", "metadata"]).row or {}
            if vector is None:
                vector = current.get("vector")
            if payload is None:
                payload = current.get("metadata")
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector, or None when the query yields no row.
        """
        result = self._table.query(primary_key={"id": vector_id}, projections=["id", "metadata"])
        row = result.row
        if not row:
            # Avoid an AttributeError on a missing row.
            return None
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """
        List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        return [table.table_name for table in self._database.list_table()]

    def delete_col(self):
        """Delete the table and wait until it is gone.

        Raises:
            TimeoutError: If the table still exists after the wait timeout.
        """
        try:
            # skip drop table if table not exists
            if not any(table.table_name == self.table_name for table in self._database.list_table()):
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            # Delete the table
            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # Wait for table to be completely deleted
            deadline = time.monotonic() + self._TABLE_WAIT_TIMEOUT_SECONDS
            while True:
                try:
                    self._database.describe_table(self.table_name)
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"Table {self.table_name} not deleted after {self._TABLE_WAIT_TIMEOUT_SECONDS}s"
                    )
                logger.info(f"Waiting for table {self.table_name} to be deleted...")
                time.sleep(self._POLL_INTERVAL_SECONDS)
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """
        Get information about the table.

        Returns:
            Dict[str, Any]: Table information as returned by the table's ``stats()``.
        """
        return self._table.stats()

    def list(self, filters: Optional[dict] = None, limit: int = 100) -> list:
        """
        List all vectors in the table.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A one-element list wrapping the list of vectors.
        """
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=["id", "metadata"], limit=limit)
        memories = [OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {})) for row in result.rows]
        # The result is deliberately wrapped in an outer list.
        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it."""
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    @staticmethod
    def _quote_filter_literal(text) -> str:
        """Wrap ``text`` in double quotes, escaping backslashes and double quotes."""
        # NOTE(review): assumes the filter grammar honours backslash escapes
        # inside string literals -- confirm against the Mochow filter syntax.
        escaped = str(text).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _create_filter(self, filters: dict) -> str:
        """
        Create filter expression for queries.

        Each entry becomes an equality test on the matching ``metadata`` key
        and the tests are joined with ``AND``. Keys and string values are
        quoted and escaped so that a quote inside a value cannot end the
        literal early and change the meaning of the expression.

        Args:
            filters (dict): Filter conditions.

        Returns:
            str: Filter expression.
        """
        conditions = []
        for key, value in filters.items():
            field = f"metadata[{self._quote_filter_literal(key)}]"
            if isinstance(value, str):
                conditions.append(f"{field} = {self._quote_filter_literal(value)}")
            else:
                conditions.append(f"{field} = {value}")
        return " AND ".join(conditions)

View File

@@ -17,6 +17,7 @@ class VectorStoreConfig(BaseModel):
"pinecone": "PineconeConfig",
"mongodb": "MongoDBConfig",
"milvus": "MilvusDBConfig",
"baidu": "BaiduDBConfig",
"upstash_vector": "UpstashVectorConfig",
"azure_ai_search": "AzureAISearchConfig",
"redis": "RedisDBConfig",

View File

@@ -0,0 +1,233 @@
from unittest.mock import Mock, patch, PropertyMock
import pytest
from mem0.vector_stores.baidu import BaiduDB, OutputData
from pymochow.model.enum import MetricType, TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError
@pytest.fixture
def mock_mochow_client():
    """Replace ``pymochow.MochowClient`` with a mock for the duration of a test."""
    patcher = patch("pymochow.MochowClient")
    client_cls = patcher.start()
    try:
        yield client_cls
    finally:
        # Mirror the context-manager form: always undo the patch.
        patcher.stop()
@pytest.fixture
def mock_configuration():
    """Replace the ``Configuration`` name used by the BaiduDB module with a mock."""
    # The store binds the name with ``from pymochow.configuration import
    # Configuration``, so the patch must target the name in the module that
    # uses it; patching the defining module would leave that binding untouched.
    with patch("mem0.vector_stores.baidu.Configuration") as mock_config:
        yield mock_config
@pytest.fixture
def mock_bce_credentials():
    """Replace the ``BceCredentials`` name used by the BaiduDB module with a mock."""
    # Patch the name where it is looked up (the store module imported it with
    # ``from ... import``), not where it is defined.
    with patch("mem0.vector_stores.baidu.BceCredentials") as mock_creds:
        yield mock_creds
@pytest.fixture
def mock_table():
    """Build a ``Table``-spec'd mock whose properties return fixed test values."""
    table = Mock(spec=Table)
    # Set the properties of the Table class.
    property_values = {
        "database_name": "test_db",
        "table_name": "test_table",
        "schema": Mock(),
        "replication": 1,
        "partition": Mock(),
        "enable_dynamic_field": False,
        "description": "",
        "create_time": "",
        "state": TableState.NORMAL,
        "aliases": [],
    }
    for attribute, value in property_values.items():
        # Properties live on the type, so the PropertyMock goes on the mock's class.
        setattr(type(table), attribute, PropertyMock(return_value=value))
    return table
@pytest.fixture
def mochow_instance(mock_mochow_client, mock_configuration, mock_bce_credentials, mock_table):
    """Construct a ``BaiduDB`` wired to a mocked client, database and table."""
    database = Mock()
    # Mock table operations
    database.list_table.return_value = []
    database.create_table.return_value = mock_table
    database.describe_table.return_value = Mock(state=TableState.NORMAL)
    database.table.return_value = mock_table

    client = Mock()
    # Mock database operations
    client.list_databases.return_value = []
    client.create_database.return_value = database
    client.database.return_value = database

    # Mock the client creation
    mock_mochow_client.return_value = client

    settings = {
        "endpoint": "http://localhost:8287",
        "account": "test_account",
        "api_key": "test_api_key",
        "database_name": "test_db",
        "table_name": "test_table",
        "embedding_model_dims": 128,
        "metric_type": "COSINE",
    }
    return BaiduDB(**settings)
def test_insert(mochow_instance, mock_mochow_client):
    """insert() issues one upsert per vector, each carrying the matching id and payload."""
    ids = ["id1", "id2"]
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]

    mochow_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

    # Verify table.upsert was called with correct data
    upsert = mochow_instance._table.upsert
    assert upsert.call_count == 2
    # Check every call against the input at the same position.
    for call, expected_id, expected_vector, expected_payload in zip(upsert.call_args_list, ids, vectors, payloads):
        row = call.kwargs["rows"][0]
        assert row._data["id"] == expected_id
        assert row._data["vector"] == expected_vector
        assert row._data["metadata"] == expected_payload
def test_search(mochow_instance, mock_mochow_client):
    """search() sends one top-k request and maps every returned row to an OutputData."""
    query_vector = [0.1, 0.2, 0.3]
    expected = [
        ("id1", 0.1, {"name": "vector1"}),
        ("id2", 0.2, {"name": "vector2"}),
    ]
    # Mock search results
    vector_search = mochow_instance._table.vector_search
    vector_search.return_value = Mock(
        rows=[{"row": {"id": rid, "metadata": meta}, "score": score} for rid, score, meta in expected]
    )

    results = mochow_instance.search(query="test", vectors=query_vector, limit=2)

    # Verify search was called with correct parameters
    vector_search.assert_called_once()
    args, kwargs = vector_search.call_args
    request = args[0] if args else kwargs["request"]
    assert isinstance(request, VectorTopkSearchRequest)
    assert request._vector_field == "vector"
    assert isinstance(request._vector, FloatVector)
    assert request._vector._floats == query_vector
    assert request._limit == 2
    assert isinstance(request._config, VectorSearchConfig)
    assert request._config._ef == 200

    # Verify results
    assert len(results) == 2
    for result, (rid, score, meta) in zip(results, expected):
        assert result.id == rid
        assert result.score == score
        assert result.payload == meta
def test_search_with_filters(mochow_instance, mock_mochow_client):
    """search() renders the filters dict into the request's filter expression."""
    vector_search = mochow_instance._table.vector_search
    vector_search.return_value = Mock(rows=[])

    mochow_instance.search(
        query="test",
        vectors=[0.1, 0.2, 0.3],
        limit=2,
        filters={"user_id": "user123", "agent_id": "agent456"},
    )

    # Verify search was called with filter
    args, kwargs = vector_search.call_args
    request = args[0] if args else kwargs["request"]
    assert request._filter == 'metadata["user_id"] = "user123" AND metadata["agent_id"] = "agent456"'
def test_delete(mochow_instance, mock_mochow_client):
    """delete() forwards the id to the table as a primary-key lookup."""
    mochow_instance.delete(vector_id="id1")

    expected_key = {"id": "id1"}
    mochow_instance._table.delete.assert_called_once_with(primary_key=expected_key)
def test_update(mochow_instance, mock_mochow_client):
    """update() upserts exactly one row holding the new vector and payload."""
    vector_id = "id1"
    new_vector = [0.7, 0.8, 0.9]
    new_payload = {"name": "updated_vector"}

    mochow_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)

    upsert = mochow_instance._table.upsert
    upsert.assert_called_once()
    args, kwargs = upsert.call_args
    # ``rows`` is a list whether it is passed positionally or by keyword, so
    # take the list first and index into it afterwards.
    rows = args[0] if args else kwargs["rows"]
    row = rows[0]
    assert row._data["id"] == vector_id
    assert row._data["vector"] == new_vector
    assert row._data["metadata"] == new_payload
def test_get(mochow_instance, mock_mochow_client):
    """get() queries by primary key and wraps the returned row in an OutputData."""
    # Mock query result
    query = mochow_instance._table.query
    query.return_value = Mock(row={"id": "id1", "metadata": {"name": "vector1"}})

    fetched = mochow_instance.get(vector_id="id1")

    query.assert_called_once_with(primary_key={"id": "id1"}, projections=["id", "metadata"])
    assert fetched.id == "id1"
    assert fetched.score is None
    assert fetched.payload == {"name": "vector1"}
def test_list(mochow_instance, mock_mochow_client):
    """list() selects without a filter and returns the rows wrapped in an outer list."""
    # Mock select result
    stored_rows = [
        {"id": "id1", "metadata": {"name": "vector1"}},
        {"id": "id2", "metadata": {"name": "vector2"}},
    ]
    select = mochow_instance._table.select
    select.return_value = Mock(rows=stored_rows)

    results = mochow_instance.list(limit=2)

    select.assert_called_once_with(filter=None, projections=["id", "metadata"], limit=2)
    memories = results[0]
    assert len(memories) == 2
    assert memories[0].id == "id1"
    assert memories[1].id == "id2"
def test_list_cols(mochow_instance, mock_mochow_client):
    """list_cols() returns the table names reported by the database."""
    # Mock table list
    table_names = ["table1", "table2"]
    mochow_instance._database.list_table.return_value = [
        Mock(spec=Table, database_name="test_db", table_name=table_name) for table_name in table_names
    ]

    assert mochow_instance.list_cols() == ["table1", "table2"]
def test_delete_col_not_exists(mochow_instance, mock_mochow_client):
    """delete_col() is a no-op when the table is not listed by the database."""
    database = mochow_instance._database
    # Use the correct ServerErrCode enum value. drop_table would raise if it
    # were reached, so a regression of the "skip when absent" check fails here.
    database.drop_table.side_effect = ServerError("Table not exists", code=ServerErrCode.TABLE_NOT_EXIST)

    # Should not raise exception
    mochow_instance.delete_col()

    # The fixture's list_table returns no tables, so the drop must be skipped.
    database.drop_table.assert_not_called()
def test_col_info(mochow_instance, mock_mochow_client):
    """col_info() returns whatever the table's stats() call reports."""
    expected_info = {"table_name": "test_table", "fields": []}
    mochow_instance._table.stats.return_value = expected_info

    assert mochow_instance.col_info() == expected_info