# Files
# t6_mem0/mem0/vector_stores/baidu.py
#
# 369 lines
# 12 KiB
# Python

import logging
import time
import uuid
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import pymochow
    from pymochow.auth.bce_credentials import BceCredentials
    from pymochow.configuration import Configuration
    from pymochow.exception import ServerError
    from pymochow.model.enum import (
        FieldType,
        IndexType,
        MetricType,
        ServerErrCode,
        TableState,
    )
    from pymochow.model.schema import (
        AutoBuildRowCountIncrement,
        Field,
        FilteringIndex,
        HNSWParams,
        Schema,
        VectorIndex,
    )
    from pymochow.model.table import (
        FloatVector,
        Partition,
        Row,
        VectorSearchConfig,
        VectorTopkSearchRequest,
    )
except ImportError:
    raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """Single record returned by the vector store's read operations.

    All three fields accept ``None``; no defaults are declared, so callers
    pass every field explicitly when constructing an instance.
    """

    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata
class BaiduDB(VectorStoreBase):
    """Vector store backed by a Baidu VectorDB (Mochow) table.

    Every record is stored as one row with three fields: ``id`` (string
    primary key), ``vector`` (float vector) and ``metadata`` (JSON payload).
    """

    # Seconds slept between polls while waiting for a table to become ready
    # (after creation) or to disappear (after a drop).
    _POLL_INTERVAL_SECONDS = 2

    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the BaiduDB database.

        Connects to the service, then makes sure the database and the table
        exist (creating them when necessary).

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search, given
                either as a ``MetricType`` member or as the name of one.
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure database and table exist
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Bind ``self._database`` to the configured database, creating it if absent.

        Raises:
            Exception: Any error from the client is logged and re-raised unchanged.
        """
        try:
            databases = self.client.list_databases()
            db_exists = any(db.database_name == self.database_name for db in databases)
            if not db_exists:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
            else:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    @staticmethod
    def _resolve_metric_type(distance):
        """Return the ``MetricType`` member designated by ``distance``.

        Args:
            distance: A ``MetricType`` member, or the name of one as a string.

        Returns:
            MetricType: The matching enum member.

        Raises:
            ValueError: If ``distance`` is neither a member nor a member name.
        """
        if isinstance(distance, MetricType):
            return distance
        metric_type = MetricType.__members__.get(distance) if isinstance(distance, str) else None
        if metric_type is None:
            raise ValueError(f"Unsupported metric_type: {distance}")
        return metric_type

    def create_col(self, name, vector_size, distance):
        """Create a new table, or attach to it when it already exists.

        On success ``self._table`` refers to the table named ``name``.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (MetricType | str): Metric type for similarity search, as
                a ``MetricType`` member or the name of one.

        Raises:
            ValueError: If ``distance`` does not designate a ``MetricType`` member.
            Exception: Any error from the client is logged and re-raised unchanged.
        """
        try:
            # Check if table already exists
            tables = self._database.list_table()
            if any(table.table_name == name for table in tables):
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            metric_type = self._resolve_metric_type(distance)

            # Define table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # Vector index for similarity search plus a filtering index on the metadata
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]
            schema = Schema(fields=fields, indexes=indexes)

            # Create table
            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # Wait for table to be ready.
            # NOTE: this poll has no upper bound; it only ends once the table
            # reports TableState.NORMAL or describe_table raises.
            while True:
                time.sleep(self._POLL_INTERVAL_SECONDS)
                table = self._database.describe_table(name)
                if table.state == TableState.NORMAL:
                    logger.info(f"Table {name} is ready.")
                    break
                logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            self._table = table
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the table.

        Rows are written with ``upsert``, one request per row.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
                Defaults to an empty dict per vector.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Defaults to a freshly generated UUID4 string per vector.

        Raises:
            ValueError: If ``vectors``, ``payloads`` and ``ids`` differ in length.
        """
        vectors = list(vectors)
        if payloads is None:
            payloads = [{} for _ in vectors]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        # zip() would silently drop the surplus entries of the longer inputs.
        if not (len(vectors) == len(payloads) == len(ids)):
            raise ValueError(
                "vectors, payloads and ids must have the same length "
                f"(got {len(vectors)}, {len(payloads)} and {len(ids)})"
            )

        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(
        self,
        query: str,
        vectors: List[float],
        limit: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[OutputData]:
        """Search for similar vectors.

        Args:
            query (str): Query string. Not used by this implementation; only
                ``vectors`` takes part in the search.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        # Add filters if provided
        search_filter = self._create_filter(filters) if filters else None

        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        # Perform search; the vector itself is not requested back
        res = self._table.vector_search(request=request, projections=["id", "metadata"])

        # Each result carries the stored fields under "row" and the score beside it
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output.append(
                OutputData(
                    id=row_data.get("id"),
                    score=row.get("score", 0.0),
                    payload=row_data.get("metadata", {}),
                )
            )
        return output

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """Update a vector and its payload.

        The row is rewritten with ``upsert`` using exactly the values given;
        an argument left at ``None`` is passed through as ``None``.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: The stored id and metadata (``score`` is
            always ``None``), or ``None`` when the query returns no row.
        """
        result = self._table.query(primary_key={"id": vector_id}, projections=["id", "metadata"])
        row = result.row
        if not row:
            return None
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        return [table.table_name for table in self._database.list_table()]

    def delete_col(self):
        """Delete the table and wait until it is gone.

        Does nothing (beyond logging) when the table does not exist.

        Raises:
            Exception: Any error from the client is logged and re-raised unchanged.
        """
        try:
            tables = self._database.list_table()

            # skip drop table if table not exists
            if not any(table.table_name == self.table_name for table in tables):
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            # Delete the table
            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # Wait for table to be completely deleted: describe_table keeps
            # succeeding until then. NOTE: this poll has no upper bound.
            while True:
                time.sleep(self._POLL_INTERVAL_SECONDS)
                try:
                    self._database.describe_table(self.table_name)
                    logger.info(f"Waiting for table {self.table_name} to be deleted...")
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """Get information about the table.

        Returns:
            The result of the table's ``stats()`` call.
        """
        return self._table.stats()

    def list(self, filters: Optional[Dict[str, Any]] = None, limit: int = 100) -> List[List[OutputData]]:
        """List vectors in the table.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A one-element list whose only item is the
            list of records (``score`` is always ``None``).
        """
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=["id", "metadata"], limit=limit)
        memories = [
            OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {})) for row in result.rows
        ]
        # The records are deliberately wrapped in an outer list.
        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it.

        Raises:
            Exception: Any error from deletion or creation is logged and re-raised.
        """
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    def _create_filter(self, filters: Dict[str, Any]) -> str:
        """Create filter expression for queries.

        Every key/value pair becomes an equality test on the ``metadata``
        field; the tests are combined with ``AND``. String values are wrapped
        in double quotes, any other value is interpolated with its ``str()``
        form.

        Args:
            filters (dict): Filter conditions.

        Returns:
            str: Filter expression (empty string for an empty mapping).
        """
        # NOTE(review): keys and values are interpolated verbatim. A value
        # containing a double quote yields a malformed (or altered) expression;
        # confirm the escaping rules of the filter grammar before accepting
        # untrusted input here.
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] = "{value}"')
            else:
                conditions.append(f'metadata["{key}"] = {value}')
        return " AND ".join(conditions)