Feature: baidu vector db integration (#2929)

Makefile
@@ -13,7 +13,7 @@ install:
 install_all:
 	pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
 		google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
-		upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25
+		upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow

 # Format code with ruff
 format:

docs/components/vectordbs/dbs/baidu.mdx (new file)
@@ -0,0 +1,67 @@
---
title: Baidu VectorDB (Mochow)
---

[Baidu VectorDB](https://cloud.baidu.com/doc/VDB/index.html) is an enterprise-grade distributed vector database service developed by Baidu Intelligent Cloud. It is powered by Baidu's proprietary "Mochow" vector database kernel, which provides high performance, high availability, and security for vector search.

### Usage

```python
from mem0 import Memory

config = {
    "vector_store": {
        "provider": "baidu",
        "config": {
            "endpoint": "http://your-mochow-endpoint:8287",
            "account": "root",
            "api_key": "your-api-key",
            "database_name": "mem0",
            "table_name": "mem0_table",
            "embedding_model_dims": 1536,
            "metric_type": "COSINE",
        },
    },
}

m = Memory.from_config(config)
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."},
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
### Config

Here are the parameters available for the `baidu` config:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `endpoint` | Endpoint URL for your Baidu VectorDB instance | Required |
| `account` | Baidu VectorDB account name | `root` |
| `api_key` | API key for accessing Baidu VectorDB | Required |
| `database_name` | Name of the database | `mem0` |
| `table_name` | Name of the table | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `metric_type` | Distance metric for similarity search | `L2` |
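
Per the table above, only `endpoint` and `api_key` must be supplied explicitly; everything else falls back to its default. A minimal configuration (a sketch, with a placeholder endpoint) looks like this:

```python
config = {
    "vector_store": {
        "provider": "baidu",
        "config": {
            "endpoint": "http://your-mochow-endpoint:8287",  # placeholder
            "api_key": "your-api-key",
            # account, database_name, table_name, embedding_model_dims,
            # and metric_type fall back to the defaults listed above
        },
    },
}
```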

### Distance Metrics

The following distance metrics are supported:

- `L2`: Euclidean distance (default)
- `IP`: Inner product
- `COSINE`: Cosine similarity
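
Internally, the integration resolves these strings against pymochow's `MetricType` enum (see `create_col` in `mem0/vector_stores/baidu.py` later in this commit); an unrecognized value raises a `ValueError`. A minimal illustration of the lookup:

```python
from pymochow.model.enum import MetricType

# Resolve a configured string such as "COSINE" to the enum member;
# names outside L2 / IP / COSINE come back as None and are rejected.
metric = MetricType.__members__.get("COSINE")
assert metric is MetricType.COSINE
```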

### Index Configuration

The vector index is automatically configured with the following HNSW parameters:

- `m`: 16 (number of connections per element)
- `efconstruction`: 200 (size of the dynamic candidate list during construction)
- `auto_build`: true (build the index automatically)
- `auto_build_index_policy`: incremental build, triggered every 10000 new rows
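
For reference, these defaults correspond to the pymochow index definition built by `create_col` in `mem0/vector_stores/baidu.py` (shown below in this commit):

```python
from pymochow.model.enum import IndexType, MetricType
from pymochow.model.schema import AutoBuildRowCountIncrement, HNSWParams, VectorIndex

# Mirrors the index the integration creates; metric_type is whatever
# the configured string resolved to (COSINE here as an example).
vector_index = VectorIndex(
    index_name="vector_idx",
    index_type=IndexType.HNSW,
    field="vector",
    metric_type=MetricType.COSINE,
    params=HNSWParams(m=16, efconstruction=200),
    auto_build=True,
    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
)
```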

@@ -146,7 +146,8 @@
         "components/vectordbs/dbs/vertex_ai",
         "components/vectordbs/dbs/weaviate",
         "components/vectordbs/dbs/faiss",
-        "components/vectordbs/dbs/langchain"
+        "components/vectordbs/dbs/langchain",
+        "components/vectordbs/dbs/baidu"
       ]
     }
   ]

mem0/configs/vector_stores/baidu.py (new file)
@@ -0,0 +1,30 @@
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class BaiduDBConfig(BaseModel):
    endpoint: str = Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB")
    account: str = Field("root", description="Account for Baidu VectorDB")
    api_key: Optional[str] = Field(None, description="API Key for Baidu VectorDB")
    database_name: str = Field("mem0", description="Name of the database")
    table_name: str = Field("mem0", description="Name of the table")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please provide only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
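
As a quick sanity check of the validator above, a sketch of how `BaiduDBConfig` behaves (assuming pydantic v2 semantics, where a `ValueError` raised in a validator surfaces as a `ValidationError`, itself a subclass of `ValueError`):

```python
from mem0.configs.vector_stores.baidu import BaiduDBConfig

# Defaults cover everything except the credentials.
cfg = BaiduDBConfig(api_key="your-api-key")
print(cfg.endpoint, cfg.metric_type)  # http://localhost:8287 L2

# Unknown keys are rejected by validate_extra_fields.
try:
    BaiduDBConfig(api_key="your-api-key", colection="typo")  # "colection" is a deliberate typo
except ValueError as e:
    print(e)
```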

mem0/vector_stores/baidu.py (new file)
@@ -0,0 +1,349 @@
import logging
import time
from typing import Dict, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import pymochow
    from pymochow.configuration import Configuration
    from pymochow.auth.bce_credentials import BceCredentials
    from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode
    from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
    from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector
    from pymochow.exception import ServerError
except ImportError:
    raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")

logger = logging.getLogger(__name__)

class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata

class BaiduDB(VectorStoreBase):
    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: str,
    ) -> None:
        """Initialize the BaiduDB vector store.

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (str): Metric type for similarity search, e.g. "L2", "IP", or "COSINE".
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize the Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure the database and table exist
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Create the database if it doesn't exist."""
        try:
            # Check whether the database already exists
            databases = self.client.list_databases()
            db_exists = any(db.database_name == self.database_name for db in databases)
            if not db_exists:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
            else:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    def create_col(self, name, vector_size, distance):
        """Create a new table.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (str): Metric type for similarity search.
        """
        try:
            # Skip creation if the table already exists
            tables = self._database.list_table()
            table_exists = any(table.table_name == name for table in tables)
            if table_exists:
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            # Convert the distance string to a MetricType enum member
            metric_type = MetricType.__members__.get(distance)
            if metric_type is None:
                raise ValueError(f"Unsupported metric_type: {distance}")

            # Define the table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # HNSW vector index plus a filtering index over the metadata column
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]

            schema = Schema(fields=fields, indexes=indexes)

            # Create the table
            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # Wait for the table to become ready
            while True:
                time.sleep(2)
                table = self._database.describe_table(name)
                if table.state == TableState.NORMAL:
                    logger.info(f"Table {name} is ready.")
                    break
                logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            self._table = table
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the table.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        # Upsert each row individually
        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """Search for similar vectors.

        Args:
            query (str): Query string.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """
        # Build the filter expression if filters were provided
        search_filter = None
        if filters:
            search_filter = self._create_filter(filters)

        # Build the top-k vector search request
        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        # Perform the search
        projections = ["id", "metadata"]
        res = self._table.vector_search(request=request, projections=projections)

        # Parse the results
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output_data = OutputData(
                id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {})
            )
            output.append(output_data)

        return output

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        projections = ["id", "metadata"]
        result = self._table.query(primary_key={"id": vector_id}, projections=projections)
        row = result.row
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        tables = self._database.list_table()
        return [table.table_name for table in tables]

    def delete_col(self):
        """Delete the table."""
        try:
            tables = self._database.list_table()

            # Skip the drop if the table does not exist
            table_exists = any(table.table_name == self.table_name for table in tables)
            if not table_exists:
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            # Delete the table
            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # Wait until the table has been completely deleted
            while True:
                time.sleep(2)
                try:
                    self._database.describe_table(self.table_name)
                    logger.info(f"Waiting for table {self.table_name} to be deleted...")
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """Get information about the table.

        Returns:
            Dict[str, Any]: Table information.
        """
        return self._table.stats()

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """List all vectors in the table.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A single-element list containing the list of vectors.
        """
        projections = ["id", "metadata"]
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=projections, limit=limit)

        memories = []
        for row in result.rows:
            obj = OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))
            memories.append(obj)

        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it."""
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    def _create_filter(self, filters: dict) -> str:
        """Create a filter expression for queries.

        Args:
            filters (dict): Filter conditions.

        Returns:
            str: Filter expression.
        """
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] = "{value}"')
            else:
                conditions.append(f'metadata["{key}"] = {value}')
        return " AND ".join(conditions)
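
`_create_filter` doesn't touch instance state, so its output can be illustrated without a live connection (a sketch; the expression format matches what the tests below assert):

```python
from mem0.vector_stores.baidu import BaiduDB

expr = BaiduDB._create_filter(None, {"user_id": "alice", "year": 2024})
# String values are quoted, everything else is interpolated as-is:
print(expr)  # metadata["user_id"] = "alice" AND metadata["year"] = 2024
```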

@@ -17,6 +17,7 @@ class VectorStoreConfig(BaseModel):
         "pinecone": "PineconeConfig",
         "mongodb": "MongoDBConfig",
         "milvus": "MilvusDBConfig",
+        "baidu": "BaiduDBConfig",
         "upstash_vector": "UpstashVectorConfig",
         "azure_ai_search": "AzureAISearchConfig",
         "redis": "RedisDBConfig",
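
With this registry entry in place, a `vector_store` config that names `provider="baidu"` resolves to `BaiduDBConfig`. A sketch, assuming `VectorStoreConfig` is importable from `mem0.vector_stores.configs` (the file is not named in this hunk):

```python
from mem0.vector_stores.configs import VectorStoreConfig

vs = VectorStoreConfig(provider="baidu", config={"api_key": "your-api-key"})
print(type(vs.config).__name__)  # expected: BaiduDBConfig
```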

tests/vector_stores/test_baidu.py (new file)
@@ -0,0 +1,233 @@
from unittest.mock import Mock, PropertyMock, patch

import pytest

from mem0.vector_stores.baidu import BaiduDB
from pymochow.model.enum import TableState, ServerErrCode
from pymochow.model.table import FloatVector, Table, VectorSearchConfig, VectorTopkSearchRequest
from pymochow.exception import ServerError

@pytest.fixture
def mock_mochow_client():
    with patch("pymochow.MochowClient") as mock_client:
        yield mock_client


@pytest.fixture
def mock_configuration():
    with patch("pymochow.configuration.Configuration") as mock_config:
        yield mock_config


@pytest.fixture
def mock_bce_credentials():
    with patch("pymochow.auth.bce_credentials.BceCredentials") as mock_creds:
        yield mock_creds


@pytest.fixture
def mock_table():
    mock_table = Mock(spec=Table)
    # Set the Table class attributes
    type(mock_table).database_name = PropertyMock(return_value="test_db")
    type(mock_table).table_name = PropertyMock(return_value="test_table")
    type(mock_table).schema = PropertyMock(return_value=Mock())
    type(mock_table).replication = PropertyMock(return_value=1)
    type(mock_table).partition = PropertyMock(return_value=Mock())
    type(mock_table).enable_dynamic_field = PropertyMock(return_value=False)
    type(mock_table).description = PropertyMock(return_value="")
    type(mock_table).create_time = PropertyMock(return_value="")
    type(mock_table).state = PropertyMock(return_value=TableState.NORMAL)
    type(mock_table).aliases = PropertyMock(return_value=[])
    return mock_table


@pytest.fixture
def mochow_instance(mock_mochow_client, mock_configuration, mock_bce_credentials, mock_table):
    mock_database = Mock()
    mock_client_instance = Mock()

    # Mock the client creation
    mock_mochow_client.return_value = mock_client_instance

    # Mock database operations
    mock_client_instance.list_databases.return_value = []
    mock_client_instance.create_database.return_value = mock_database
    mock_client_instance.database.return_value = mock_database

    # Mock table operations
    mock_database.list_table.return_value = []
    mock_database.create_table.return_value = mock_table
    mock_database.describe_table.return_value = Mock(state=TableState.NORMAL)
    mock_database.table.return_value = mock_table

    return BaiduDB(
        endpoint="http://localhost:8287",
        account="test_account",
        api_key="test_api_key",
        database_name="test_db",
        table_name="test_table",
        embedding_model_dims=128,
        metric_type="COSINE",
    )

def test_insert(mochow_instance, mock_mochow_client):
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
    ids = ["id1", "id2"]

    mochow_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

    # Verify table.upsert was called once per row with the correct data
    assert mochow_instance._table.upsert.call_count == 2
    calls = mochow_instance._table.upsert.call_args_list

    # Check the first call
    first_row = calls[0][1]["rows"][0]
    assert first_row._data["id"] == "id1"
    assert first_row._data["vector"] == [0.1, 0.2, 0.3]
    assert first_row._data["metadata"] == {"name": "vector1"}

    # Check the second call
    second_row = calls[1][1]["rows"][0]
    assert second_row._data["id"] == "id2"
    assert second_row._data["vector"] == [0.4, 0.5, 0.6]
    assert second_row._data["metadata"] == {"name": "vector2"}


def test_search(mochow_instance, mock_mochow_client):
    # Mock search results
    mock_search_results = Mock()
    mock_search_results.rows = [
        {"row": {"id": "id1", "metadata": {"name": "vector1"}}, "score": 0.1},
        {"row": {"id": "id2", "metadata": {"name": "vector2"}}, "score": 0.2},
    ]
    mochow_instance._table.vector_search.return_value = mock_search_results

    vectors = [0.1, 0.2, 0.3]
    results = mochow_instance.search(query="test", vectors=vectors, limit=2)

    # Verify search was called with the correct parameters
    mochow_instance._table.vector_search.assert_called_once()
    call_args = mochow_instance._table.vector_search.call_args
    request = call_args[0][0] if call_args[0] else call_args[1]["request"]

    assert isinstance(request, VectorTopkSearchRequest)
    assert request._vector_field == "vector"
    assert isinstance(request._vector, FloatVector)
    assert request._vector._floats == vectors
    assert request._limit == 2
    assert isinstance(request._config, VectorSearchConfig)
    assert request._config._ef == 200

    # Verify the results
    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].score == 0.1
    assert results[0].payload == {"name": "vector1"}
    assert results[1].id == "id2"
    assert results[1].score == 0.2
    assert results[1].payload == {"name": "vector2"}


def test_search_with_filters(mochow_instance, mock_mochow_client):
    mochow_instance._table.vector_search.return_value = Mock(rows=[])

    vectors = [0.1, 0.2, 0.3]
    filters = {"user_id": "user123", "agent_id": "agent456"}

    mochow_instance.search(query="test", vectors=vectors, limit=2, filters=filters)

    # Verify search was called with the rendered filter expression
    call_args = mochow_instance._table.vector_search.call_args
    request = call_args[0][0] if call_args[0] else call_args[1]["request"]

    assert request._filter == 'metadata["user_id"] = "user123" AND metadata["agent_id"] = "agent456"'


def test_delete(mochow_instance, mock_mochow_client):
    vector_id = "id1"
    mochow_instance.delete(vector_id=vector_id)

    mochow_instance._table.delete.assert_called_once_with(primary_key={"id": vector_id})


def test_update(mochow_instance, mock_mochow_client):
    vector_id = "id1"
    new_vector = [0.7, 0.8, 0.9]
    new_payload = {"name": "updated_vector"}

    mochow_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)

    mochow_instance._table.upsert.assert_called_once()
    call_args = mochow_instance._table.upsert.call_args
    row = call_args[0][0] if call_args[0] else call_args[1]["rows"][0]

    assert row._data["id"] == vector_id
    assert row._data["vector"] == new_vector
    assert row._data["metadata"] == new_payload


def test_get(mochow_instance, mock_mochow_client):
    # Mock the query result
    mock_result = Mock()
    mock_result.row = {"id": "id1", "metadata": {"name": "vector1"}}
    mochow_instance._table.query.return_value = mock_result

    result = mochow_instance.get(vector_id="id1")

    mochow_instance._table.query.assert_called_once_with(primary_key={"id": "id1"}, projections=["id", "metadata"])

    assert result.id == "id1"
    assert result.score is None
    assert result.payload == {"name": "vector1"}


def test_list(mochow_instance, mock_mochow_client):
    # Mock the select result
    mock_result = Mock()
    mock_result.rows = [{"id": "id1", "metadata": {"name": "vector1"}}, {"id": "id2", "metadata": {"name": "vector2"}}]
    mochow_instance._table.select.return_value = mock_result

    results = mochow_instance.list(limit=2)

    mochow_instance._table.select.assert_called_once_with(filter=None, projections=["id", "metadata"], limit=2)

    assert len(results[0]) == 2
    assert results[0][0].id == "id1"
    assert results[0][1].id == "id2"


def test_list_cols(mochow_instance, mock_mochow_client):
    # Mock the table list
    mock_tables = [
        Mock(spec=Table, database_name="test_db", table_name="table1"),
        Mock(spec=Table, database_name="test_db", table_name="table2"),
    ]
    mochow_instance._database.list_table.return_value = mock_tables

    result = mochow_instance.list_cols()

    assert result == ["table1", "table2"]


def test_delete_col_not_exists(mochow_instance, mock_mochow_client):
    # Use the proper ServerErrCode enum value
    mochow_instance._database.drop_table.side_effect = ServerError(
        "Table not exists", code=ServerErrCode.TABLE_NOT_EXIST
    )

    # Should not raise an exception
    mochow_instance.delete_col()


def test_col_info(mochow_instance, mock_mochow_client):
    mock_table_info = {"table_name": "test_table", "fields": []}
    mochow_instance._table.stats.return_value = mock_table_info

    result = mochow_instance.col_info()

    assert result == mock_table_info