Feature: baidu vector db integration (#2929)
Makefile (2 lines changed)

@@ -13,7 +13,7 @@ install:
 install_all:
	pip install ruff==0.6.9 groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
	google-generativeai elasticsearch opensearch-py vecs "pinecone<7.0.0" pinecone-text faiss-cpu langchain-community \
-	upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25
+	upstash-vector azure-search-documents langchain-memgraph langchain-neo4j rank-bm25 pymochow

 # Format code with ruff
 format:
docs/components/vectordbs/dbs/baidu.mdx (new file, 67 lines)

@@ -0,0 +1,67 @@
---
title: Baidu VectorDB (Mochow)
---

[Baidu VectorDB](https://cloud.baidu.com/doc/VDB/index.html) is an enterprise-level distributed vector database service developed by Baidu Intelligent Cloud. It is powered by Baidu's proprietary "Mochow" vector database kernel, which provides high performance, availability, and security for vector search. The integration depends on the `pymochow` SDK, which you can install with `pip install pymochow`.
### Usage

```python
import os
from mem0 import Memory

config = {
    "vector_store": {
        "provider": "baidu",
        "config": {
            "endpoint": "http://your-mochow-endpoint:8287",
            "account": "root",
            "api_key": "your-api-key",
            "database_name": "mem0",
            "table_name": "mem0_table",
            "embedding_model_dims": 1536,
            "metric_type": "COSINE"
        }
    }
}

m = Memory.from_config(config)
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
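Once memories are added, retrieval goes through the same `Memory` API. A minimal sketch, assuming the standard `Memory.search` signature (the query text is illustrative and the exact return shape depends on your mem0 version):

```python
# Hedged sketch: Memory.search embeds the query and runs a vector search
# against the Baidu (Mochow) table configured above.
related = m.search("What movies does Alice like?", user_id="alice")
print(related)
```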
### Config

Here are the available parameters for the `baidu` config:
| Parameter | Description | Default Value |
| --- | --- | --- |
| `endpoint` | Endpoint URL for your Baidu VectorDB instance | Required |
| `account` | Baidu VectorDB account name | `root` |
| `api_key` | API key for accessing Baidu VectorDB | Required |
| `database_name` | Name of the database | `mem0` |
| `table_name` | Name of the table | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `metric_type` | Distance metric for similarity search | `L2` |
### Distance Metrics

The following distance metrics are supported:

- `L2`: Euclidean distance (default)
- `IP`: Inner product
- `COSINE`: Cosine similarity
### Index Configuration

The vector index is configured automatically with the following HNSW parameters (see the sketch after this list):

- `m`: 16 (number of connections per element)
- `efconstruction`: 200 (size of the dynamic candidate list)
- `auto_build`: true (build the index automatically)
- `auto_build_index_policy`: incremental build, triggered after every 10,000 new rows
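For orientation, this is roughly how those parameters map onto pymochow's index objects; it mirrors what the store's `create_col` does internally and is a sketch, not a separate public API:

```python
from pymochow.model.enum import IndexType, MetricType
from pymochow.model.schema import AutoBuildRowCountIncrement, HNSWParams, VectorIndex

# HNSW index as configured by the integration: m=16, efconstruction=200,
# auto-built incrementally after every 10,000 new rows.
vector_index = VectorIndex(
    index_name="vector_idx",
    index_type=IndexType.HNSW,
    field="vector",
    metric_type=MetricType.L2,  # or MetricType.IP / MetricType.COSINE
    params=HNSWParams(m=16, efconstruction=200),
    auto_build=True,
    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
)
```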
@@ -146,7 +146,8 @@
         "components/vectordbs/dbs/vertex_ai",
         "components/vectordbs/dbs/weaviate",
         "components/vectordbs/dbs/faiss",
-        "components/vectordbs/dbs/langchain"
+        "components/vectordbs/dbs/langchain",
+        "components/vectordbs/dbs/baidu"
       ]
     }
   ]
mem0/configs/vector_stores/baidu.py (new file, 30 lines)

@@ -0,0 +1,30 @@
from enum import Enum
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


class BaiduDBConfig(BaseModel):
    endpoint: str = Field("http://localhost:8287", description="Endpoint URL for Baidu VectorDB")
    account: str = Field("root", description="Account for Baidu VectorDB")
    api_key: str = Field(None, description="API Key for Baidu VectorDB")
    database_name: str = Field("mem0", description="Name of the database")
    table_name: str = Field("mem0", description="Name of the table")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
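A quick sketch of how this validator behaves in practice; the field values below are hypothetical, and pydantic wraps the raised `ValueError` in a `ValidationError` (itself a `ValueError` subclass):

```python
from mem0.configs.vector_stores.baidu import BaiduDBConfig

# Known fields pass through; unset fields keep their defaults.
cfg = BaiduDBConfig(endpoint="http://localhost:8287", api_key="test-key")
assert cfg.metric_type == "L2"

# Any unknown field is rejected by validate_extra_fields.
try:
    BaiduDBConfig(endpoint="http://localhost:8287", api_key="test-key", host="example.com")
except ValueError as e:
    print(e)  # Extra fields not allowed: host. Please input only the following fields: ...
```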
mem0/vector_stores/baidu.py (new file, 349 lines)

@@ -0,0 +1,349 @@
import logging
import time
from typing import Dict, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import pymochow
    from pymochow.configuration import Configuration
    from pymochow.auth.bce_credentials import BceCredentials
    from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode
    from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
    from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector
    from pymochow.exception import ServerError
except ImportError:
    raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class BaiduDB(VectorStoreBase):
    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the BaiduDB database.

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search.
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure database and table exist
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Create the database if it doesn't exist."""
        try:
            # Check if database exists
            databases = self.client.list_databases()
            db_exists = any(db.database_name == self.database_name for db in databases)
            if not db_exists:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
            else:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    def create_col(self, name, vector_size, distance):
        """Create a new table.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (str): Metric type for similarity search.
        """
        # Check if table already exists
        try:
            tables = self._database.list_table()
            table_exists = any(table.table_name == name for table in tables)
            if table_exists:
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            # Convert distance string to MetricType enum
            metric_type = None
            for k, v in MetricType.__members__.items():
                if k == distance:
                    metric_type = v
            if metric_type is None:
                raise ValueError(f"Unsupported metric_type: {distance}")

            # Define table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # Create vector index
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]

            schema = Schema(fields=fields, indexes=indexes)

            # Create table
            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # Wait for table to be ready
            while True:
                time.sleep(2)
                table = self._database.describe_table(name)
                if table.state == TableState.NORMAL:
                    logger.info(f"Table {name} is ready.")
                    break
                logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            self._table = table
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the table.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        # Upsert one row per vector
        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query string.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """
        # Add filters if provided
        search_filter = None
        if filters:
            search_filter = self._create_filter(filters)

        # Create the top-k request for vector search
        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        # Perform search
        projections = ["id", "metadata"]
        res = self._table.vector_search(request=request, projections=projections)

        # Parse results
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output_data = OutputData(
                id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {})
            )
            output.append(output_data)

        return output

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        projections = ["id", "metadata"]
        result = self._table.query(primary_key={"id": vector_id}, projections=projections)
        row = result.row
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """
        List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        tables = self._database.list_table()
        return [table.table_name for table in tables]

    def delete_col(self):
        """Delete the table."""
        try:
            tables = self._database.list_table()

            # Skip dropping the table if it does not exist
            table_exists = any(table.table_name == self.table_name for table in tables)
            if not table_exists:
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            # Delete the table
            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # Wait for table to be completely deleted
            while True:
                time.sleep(2)
                try:
                    self._database.describe_table(self.table_name)
                    logger.info(f"Waiting for table {self.table_name} to be deleted...")
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """
        Get information about the table.

        Returns:
            Dict[str, Any]: Table information.
        """
        return self._table.stats()

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List all vectors in the table.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        projections = ["id", "metadata"]
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=projections, limit=limit)

        memories = []
        for row in result.rows:
            obj = OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))
            memories.append(obj)

        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it."""
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    def _create_filter(self, filters: dict) -> str:
        """
        Create a filter expression for queries.

        Args:
            filters (dict): Filter conditions.

        Returns:
            str: Filter expression.
        """
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] = "{value}"')
            else:
                conditions.append(f'metadata["{key}"] = {value}')
        return " AND ".join(conditions)
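As a pointer to the filter syntax `_create_filter` emits (and which the tests below assert verbatim): string values are quoted, while other types are interpolated as-is.

```python
# Hypothetical filter values for illustration.
filters = {"user_id": "user123", "agent_id": "agent456"}
# _create_filter(filters) yields the Mochow filter expression:
#   metadata["user_id"] = "user123" AND metadata["agent_id"] = "agent456"
```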
@@ -17,6 +17,7 @@ class VectorStoreConfig(BaseModel):
         "pinecone": "PineconeConfig",
         "mongodb": "MongoDBConfig",
         "milvus": "MilvusDBConfig",
+        "baidu": "BaiduDBConfig",
         "upstash_vector": "UpstashVectorConfig",
         "azure_ai_search": "AzureAISearchConfig",
         "redis": "RedisDBConfig",
tests/vector_stores/test_baidu.py (new file, 233 lines)

@@ -0,0 +1,233 @@
from unittest.mock import Mock, patch, PropertyMock

import pytest

from mem0.vector_stores.baidu import BaiduDB, OutputData
from pymochow.model.enum import MetricType, TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError


@pytest.fixture
def mock_mochow_client():
    with patch("pymochow.MochowClient") as mock_client:
        yield mock_client


@pytest.fixture
def mock_configuration():
    with patch("pymochow.configuration.Configuration") as mock_config:
        yield mock_config


@pytest.fixture
def mock_bce_credentials():
    with patch("pymochow.auth.bce_credentials.BceCredentials") as mock_creds:
        yield mock_creds
@pytest.fixture
def mock_table():
    mock_table = Mock(spec=Table)
    # Set the Table attributes the store reads
    type(mock_table).database_name = PropertyMock(return_value="test_db")
    type(mock_table).table_name = PropertyMock(return_value="test_table")
    type(mock_table).schema = PropertyMock(return_value=Mock())
    type(mock_table).replication = PropertyMock(return_value=1)
    type(mock_table).partition = PropertyMock(return_value=Mock())
    type(mock_table).enable_dynamic_field = PropertyMock(return_value=False)
    type(mock_table).description = PropertyMock(return_value="")
    type(mock_table).create_time = PropertyMock(return_value="")
    type(mock_table).state = PropertyMock(return_value=TableState.NORMAL)
    type(mock_table).aliases = PropertyMock(return_value=[])
    return mock_table
@pytest.fixture
def mochow_instance(mock_mochow_client, mock_configuration, mock_bce_credentials, mock_table):
    mock_database = Mock()
    mock_client_instance = Mock()

    # Mock the client creation
    mock_mochow_client.return_value = mock_client_instance

    # Mock database operations
    mock_client_instance.list_databases.return_value = []
    mock_client_instance.create_database.return_value = mock_database
    mock_client_instance.database.return_value = mock_database

    # Mock table operations
    mock_database.list_table.return_value = []
    mock_database.create_table.return_value = mock_table
    mock_database.describe_table.return_value = Mock(state=TableState.NORMAL)
    mock_database.table.return_value = mock_table

    return BaiduDB(
        endpoint="http://localhost:8287",
        account="test_account",
        api_key="test_api_key",
        database_name="test_db",
        table_name="test_table",
        embedding_model_dims=128,
        metric_type="COSINE",
    )


def test_insert(mochow_instance, mock_mochow_client):
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
    ids = ["id1", "id2"]

    mochow_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

    # Verify table.upsert was called with correct data
    assert mochow_instance._table.upsert.call_count == 2
    calls = mochow_instance._table.upsert.call_args_list

    # Check first call
    first_row = calls[0][1]["rows"][0]
    assert first_row._data["id"] == "id1"
    assert first_row._data["vector"] == [0.1, 0.2, 0.3]
    assert first_row._data["metadata"] == {"name": "vector1"}

    # Check second call
    second_row = calls[1][1]["rows"][0]
    assert second_row._data["id"] == "id2"
    assert second_row._data["vector"] == [0.4, 0.5, 0.6]
    assert second_row._data["metadata"] == {"name": "vector2"}


def test_search(mochow_instance, mock_mochow_client):
    # Mock search results
    mock_search_results = Mock()
    mock_search_results.rows = [
        {"row": {"id": "id1", "metadata": {"name": "vector1"}}, "score": 0.1},
        {"row": {"id": "id2", "metadata": {"name": "vector2"}}, "score": 0.2},
    ]
    mochow_instance._table.vector_search.return_value = mock_search_results

    vectors = [0.1, 0.2, 0.3]
    results = mochow_instance.search(query="test", vectors=vectors, limit=2)

    # Verify search was called with correct parameters
    mochow_instance._table.vector_search.assert_called_once()
    call_args = mochow_instance._table.vector_search.call_args
    request = call_args[0][0] if call_args[0] else call_args[1]["request"]

    assert isinstance(request, VectorTopkSearchRequest)
    assert request._vector_field == "vector"
    assert isinstance(request._vector, FloatVector)
    assert request._vector._floats == vectors
    assert request._limit == 2
    assert isinstance(request._config, VectorSearchConfig)
    assert request._config._ef == 200

    # Verify results
    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].score == 0.1
    assert results[0].payload == {"name": "vector1"}
    assert results[1].id == "id2"
    assert results[1].score == 0.2
    assert results[1].payload == {"name": "vector2"}


def test_search_with_filters(mochow_instance, mock_mochow_client):
    mochow_instance._table.vector_search.return_value = Mock(rows=[])

    vectors = [0.1, 0.2, 0.3]
    filters = {"user_id": "user123", "agent_id": "agent456"}

    mochow_instance.search(query="test", vectors=vectors, limit=2, filters=filters)

    # Verify search was called with filter
    call_args = mochow_instance._table.vector_search.call_args
    request = call_args[0][0] if call_args[0] else call_args[1]["request"]

    assert request._filter == 'metadata["user_id"] = "user123" AND metadata["agent_id"] = "agent456"'


def test_delete(mochow_instance, mock_mochow_client):
    vector_id = "id1"
    mochow_instance.delete(vector_id=vector_id)

    mochow_instance._table.delete.assert_called_once_with(primary_key={"id": vector_id})


def test_update(mochow_instance, mock_mochow_client):
    vector_id = "id1"
    new_vector = [0.7, 0.8, 0.9]
    new_payload = {"name": "updated_vector"}

    mochow_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)

    mochow_instance._table.upsert.assert_called_once()
    call_args = mochow_instance._table.upsert.call_args
    row = call_args[0][0] if call_args[0] else call_args[1]["rows"][0]

    assert row._data["id"] == vector_id
    assert row._data["vector"] == new_vector
    assert row._data["metadata"] == new_payload


def test_get(mochow_instance, mock_mochow_client):
    # Mock query result
    mock_result = Mock()
    mock_result.row = {"id": "id1", "metadata": {"name": "vector1"}}
    mochow_instance._table.query.return_value = mock_result

    result = mochow_instance.get(vector_id="id1")

    mochow_instance._table.query.assert_called_once_with(primary_key={"id": "id1"}, projections=["id", "metadata"])

    assert result.id == "id1"
    assert result.score is None
    assert result.payload == {"name": "vector1"}


def test_list(mochow_instance, mock_mochow_client):
    # Mock select result
    mock_result = Mock()
    mock_result.rows = [{"id": "id1", "metadata": {"name": "vector1"}}, {"id": "id2", "metadata": {"name": "vector2"}}]
    mochow_instance._table.select.return_value = mock_result

    results = mochow_instance.list(limit=2)

    mochow_instance._table.select.assert_called_once_with(filter=None, projections=["id", "metadata"], limit=2)

    assert len(results[0]) == 2
    assert results[0][0].id == "id1"
    assert results[0][1].id == "id2"


def test_list_cols(mochow_instance, mock_mochow_client):
    # Mock table list
    mock_tables = [
        Mock(spec=Table, database_name="test_db", table_name="table1"),
        Mock(spec=Table, database_name="test_db", table_name="table2"),
    ]
    mochow_instance._database.list_table.return_value = mock_tables

    result = mochow_instance.list_cols()

    assert result == ["table1", "table2"]
def test_delete_col_not_exists(mochow_instance, mock_mochow_client):
    # Use the matching ServerErrCode enum value
    mochow_instance._database.drop_table.side_effect = ServerError(
        "Table not exists", code=ServerErrCode.TABLE_NOT_EXIST
    )

    # Should not raise an exception
    mochow_instance.delete_col()
def test_col_info(mochow_instance, mock_mochow_client):
    mock_table_info = {"table_name": "test_table", "fields": []}
    mochow_instance._table.stats.return_value = mock_table_info

    result = mochow_instance.col_info()

    assert result == mock_table_info