369 lines
12 KiB
Python
369 lines
12 KiB
Python
import logging
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from mem0.vector_stores.base import VectorStoreBase
|
|
|
|
try:
|
|
import pymochow
|
|
from pymochow.auth.bce_credentials import BceCredentials
|
|
from pymochow.configuration import Configuration
|
|
from pymochow.exception import ServerError
|
|
from pymochow.model.enum import (
|
|
FieldType,
|
|
IndexType,
|
|
MetricType,
|
|
ServerErrCode,
|
|
TableState,
|
|
)
|
|
from pymochow.model.schema import (
|
|
AutoBuildRowCountIncrement,
|
|
Field,
|
|
FilteringIndex,
|
|
HNSWParams,
|
|
Schema,
|
|
VectorIndex,
|
|
)
|
|
from pymochow.model.table import (
|
|
FloatVector,
|
|
Partition,
|
|
Row,
|
|
VectorSearchConfig,
|
|
VectorTopkSearchRequest,
|
|
)
|
|
except ImportError:
|
|
raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")
|
|
|
|
# Module-scoped logger; named after this module so log output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class OutputData(BaseModel):
    """Result record built by BaiduDB.search, BaiduDB.get and BaiduDB.list."""

    id: Optional[str]  # memory id (the row's "id" primary key)
    score: Optional[float]  # distance/score from vector search; None for get/list results
    payload: Optional[Dict]  # metadata (the row's "metadata" JSON field)
|
|
|
|
|
|
class BaiduDB(VectorStoreBase):
    """Vector store backed by Baidu VectorDB (Mochow) through the ``pymochow`` client.

    Memories live in a single table whose rows have three fields: ``id`` (string
    primary and partition key), ``vector`` (float vector with
    ``embedding_model_dims`` dimensions) and ``metadata`` (JSON payload). The
    database and the table are created on construction when they do not exist.
    """

    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the BaiduDB database.

        Connects to the endpoint, then ensures the database and the table exist.

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search; either a
                ``MetricType`` member or the name of a member.
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure database and table exist; these calls set self._database and self._table.
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Bind ``self._database``, creating the database first if it does not exist.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            databases = self.client.list_databases()
            db_exists = any(db.database_name == self.database_name for db in databases)
            if not db_exists:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
            else:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    def _table_exists(self, name):
        """Return True if a table named ``name`` is listed in the current database."""
        return any(table.table_name == name for table in self._database.list_table())

    @staticmethod
    def _resolve_metric_type(distance):
        """Map ``distance`` to a ``MetricType`` member.

        Args:
            distance: A ``MetricType`` member, or the name of one as a string.

        Returns:
            MetricType: The matching member.

        Raises:
            ValueError: If ``distance`` is neither a member nor a member name.
        """
        if isinstance(distance, MetricType):
            return distance
        # Only strings can be member names; anything else (including unhashable
        # values) falls through to the ValueError below.
        if isinstance(distance, str) and distance in MetricType.__members__:
            return MetricType.__members__[distance]
        raise ValueError(f"Unsupported metric_type: {distance}")

    def create_col(self, name, vector_size, distance):
        """Create the table (if missing) and bind it to ``self._table``.

        If the table already exists, it is only described and bound; no schema
        or metric check is performed against it. Otherwise the table is created
        with 3 replicas, 1 partition, an HNSW vector index on ``vector`` and a
        filtering index on ``metadata``. The call then polls every 2 seconds
        until the table state is ``TableState.NORMAL``.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (str | MetricType): Metric type for similarity search, as a
                ``MetricType`` member or a member name.

        Raises:
            ValueError: If ``distance`` is not a supported metric type.
            Exception: Any client error is logged and re-raised.
        """
        try:
            if self._table_exists(name):
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            metric_type = self._resolve_metric_type(distance)

            # Define table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # HNSW index that is rebuilt automatically as rows accumulate, plus a
            # filtering index so metadata filters in search/list can use it.
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]

            schema = Schema(fields=fields, indexes=indexes)

            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # The freshly created table is polled until it reports NORMAL before use.
            while True:
                time.sleep(2)
                table = self._database.describe_table(name)
                if table.state == TableState.NORMAL:
                    logger.info(f"Table {name} is ready.")
                    break
                logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            self._table = table
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    def insert(self, vectors, payloads=None, ids=None):
        """Insert (upsert) vectors into the table, one row per vector.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): Payloads corresponding to vectors.
                Defaults to an empty dict for every vector.
            ids (List[str], optional): IDs corresponding to vectors. Defaults to
                freshly generated UUID4 strings, because ``id`` is a required
                non-auto-increment primary key.

        Raises:
            ValueError: If ``payloads`` or ``ids`` does not match ``vectors`` in length.
        """
        vectors = list(vectors)
        if payloads is None:
            payloads = [{} for _ in vectors]
        else:
            payloads = list(payloads)
        if ids is None:
            import uuid

            ids = [str(uuid.uuid4()) for _ in vectors]
        else:
            ids = list(ids)

        # zip() would silently drop trailing items on a length mismatch.
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError(
                f"Length mismatch: {len(vectors)} vectors, {len(payloads)} payloads, {len(ids)} ids"
            )

        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query string (unused; the search is by vector only).
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Metadata equality filters. Defaults to None.

        Returns:
            list: OutputData items with ``id``, ``score`` and ``payload`` (metadata).
        """
        search_filter = self._create_filter(filters) if filters else None

        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        projections = ["id", "metadata"]
        res = self._table.vector_search(request=request, projections=projections)

        # Each result is a dict holding the projected row under "row" plus a "score".
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output.append(
                OutputData(id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {}))
            )
        return output

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload by upserting a full row.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        # NOTE(review): this upserts the whole row, so passing vector=None or
        # payload=None writes None into that field rather than keeping the old
        # value — confirm callers always pass both.
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector (``score`` is None).
        """
        projections = ["id", "metadata"]
        result = self._table.query(primary_key={"id": vector_id}, projections=projections)
        row = result.row
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """
        List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        tables = self._database.list_table()
        return [table.table_name for table in tables]

    def delete_col(self):
        """Drop the table and block until it no longer exists.

        Does nothing if the table is absent. After ``drop_table`` it polls every
        2 seconds until ``describe_table`` fails with
        ``ServerErrCode.TABLE_NOT_EXIST``.

        Raises:
            Exception: Any other client or server error is logged and re-raised.
        """
        try:
            if not self._table_exists(self.table_name):
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # drop_table may return before the table is gone, so poll until the
            # server reports TABLE_NOT_EXIST.
            while True:
                time.sleep(2)
                try:
                    self._database.describe_table(self.table_name)
                    logger.info(f"Waiting for table {self.table_name} to be deleted...")
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """
        Get information about the table.

        Returns:
            Dict[str, Any]: Table statistics as returned by ``Table.stats()``.
        """
        return self._table.stats()

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List vectors in the table.

        Args:
            filters (Dict, optional): Metadata equality filters to apply.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A single-element list wrapping the results.
        """
        projections = ["id", "metadata"]
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=projections, limit=limit)

        memories = [
            OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {})) for row in result.rows
        ]

        # NOTE(review): the extra wrapping list is presumably what mem0's caller
        # expects (it indexes [0]) — confirm before changing it.
        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it."""
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    def _create_filter(self, filters: dict) -> str:
        """
        Create a filter expression that ANDs metadata equality conditions.

        Args:
            filters (dict): Mapping of metadata key to required value.

        Returns:
            str: Filter expression, e.g. ``metadata["k"] = "v" AND metadata["n"] = 1``.
        """
        # NOTE(review): keys and string values are interpolated without escaping,
        # so a '"' inside either would break the expression. Non-str values use
        # Python's str() form (e.g. True, None) — confirm the Mochow filter
        # syntax accepts these.
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] = "{value}"')
            else:
                conditions.append(f'metadata["{key}"] = {value}')
        return " AND ".join(conditions)
|