feat: enhance Azure AI Search Integration with Binary Quantization, Pre/Post Filter Options, and user agent header (#2354)

This commit is contained in:
Farzad Sunavala
2025-03-12 10:50:25 -05:00
committed by GitHub
parent 65f826e064
commit ba9c61938b
5 changed files with 788 additions and 83 deletions

View File

@@ -1,5 +1,6 @@
import json
import logging
import re
from typing import List, Optional
from pydantic import BaseModel
@@ -12,6 +13,7 @@ try:
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
BinaryQuantizationCompression,
HnswAlgorithmConfiguration,
ScalarQuantizationCompression,
SearchField,
@@ -24,7 +26,7 @@ try:
from azure.search.documents.models import VectorizedQuery
except ImportError:
raise ImportError(
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'."
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
)
logger = logging.getLogger(__name__)
@@ -37,43 +39,82 @@ class OutputData(BaseModel):
class AzureAISearch(VectorStoreBase):
def __init__(
    self,
    service_name,
    collection_name,
    api_key,
    embedding_model_dims,
    compression_type: Optional[str] = None,
    use_float16: bool = False,
):
    """
    Initialize the Azure AI Search vector store.

    Builds the document and index clients for the service, tags both with the
    "mem0" user agent, and then calls ``create_col()`` to create the index.

    Args:
        service_name (str): Azure AI Search service name.
        collection_name (str): Index name.
        api_key (str): API key for the Azure AI Search service.
        embedding_model_dims (int): Dimension of the embedding vector.
        compression_type (Optional[str]): Specifies the type of quantization to use.
            Allowed values are None (no quantization), "scalar", or "binary".
            Matching is case-insensitive; the literal "none" is also accepted.
        use_float16 (bool): Whether to store vectors in half precision (Edm.Half)
            or full precision (Edm.Single).

    Raises:
        ValueError: If compression_type is not one of the allowed values.
    """
    # None is stored as "none" so later comparisons are plain string checks.
    normalized_compression = (compression_type or "none").lower()
    if normalized_compression not in ("none", "scalar", "binary"):
        # An unrecognised value would otherwise fall through every branch in
        # create_col() and silently produce an uncompressed index.
        raise ValueError(
            f"Unsupported compression_type {compression_type!r}; "
            "expected None, 'scalar' or 'binary'."
        )

    self.index_name = collection_name
    self.collection_name = collection_name
    self.embedding_model_dims = embedding_model_dims
    self.compression_type = normalized_compression
    self.use_float16 = use_float16

    endpoint = f"https://{service_name}.search.windows.net"
    credential = AzureKeyCredential(api_key)
    self.search_client = SearchClient(
        endpoint=endpoint,
        index_name=self.index_name,
        credential=credential,
    )
    self.index_client = SearchIndexClient(
        endpoint=endpoint,
        credential=credential,
    )

    # NOTE(review): this reaches into underscore-prefixed SDK attributes to
    # register the user agent; re-check it when the SDK pin is bumped.
    self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
    self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

    self.create_col()  # create the collection / index
def create_col(self):
"""Create a new index in Azure Cognitive Search."""
vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector
if self.use_compression:
"""Create a new index in Azure AI Search."""
# Determine vector type based on use_float16 setting.
if self.use_float16:
vector_type = "Collection(Edm.Half)"
compression_name = "myCompression"
compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
else:
vector_type = "Collection(Edm.Single)"
compression_name = None
compression_configurations = []
# Configure compression settings based on the specified compression_type.
compression_configurations = []
compression_name = None
if self.compression_type == "scalar":
compression_name = "myCompression"
# For SQ, rescoring defaults to True and oversampling defaults to 4.
compression_configurations = [
ScalarQuantizationCompression(
compression_name=compression_name
# rescoring defaults to True and oversampling defaults to 4
)
]
elif self.compression_type == "binary":
compression_name = "myCompression"
# For BQ, rescoring defaults to True and oversampling defaults to 10.
compression_configurations = [
BinaryQuantizationCompression(
compression_name=compression_name
# rescoring defaults to True and oversampling defaults to 10
)
]
# If no compression is desired, compression_configurations remains empty.
fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
@@ -82,8 +123,8 @@ class AzureAISearch(VectorStoreBase):
SearchField(
name="vector",
type=vector_type,
searchable=True,
vector_search_dimensions=vector_dimensions,
searchable=True,
vector_search_dimensions=self.embedding_model_dims,
vector_search_profile_name="my-vector-config",
),
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
@@ -91,7 +132,11 @@ class AzureAISearch(VectorStoreBase):
vector_search = VectorSearch(
profiles=[
VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
VectorSearchProfile(
name="my-vector-config",
algorithm_configuration_name="my-algorithms-config",
compression_name=compression_name if self.compression_type != "none" else None
)
],
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
compressions=compression_configurations,
@@ -101,14 +146,16 @@ class AzureAISearch(VectorStoreBase):
def _generate_document(self, vector, payload, id):
    """
    Build the document dict that is uploaded to the index.

    Args:
        vector: Embedding stored under the "vector" key.
        payload (dict, optional): Metadata; serialized to JSON under "payload".
            None is treated as an empty payload.
        id: Value stored under the "id" key.

    Returns:
        dict: Document with "id", "vector", "payload" and, when present in the
        payload, top-level copies of "user_id", "run_id" and "agent_id".
    """
    # insert() defaults payloads to None; without this the membership test
    # below raises TypeError.
    payload = {} if payload is None else payload
    document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
    # Extract additional fields if they exist.
    for field in ("user_id", "run_id", "agent_id"):
        if field in payload:
            document[field] = payload[field]
    return document
# Note: Explicit "insert" calls may later be decoupled from memory management decisions.
def insert(self, vectors, payloads=None, ids=None):
    """
    Insert vectors into the index.

    Args:
        vectors (List[List[float]]): List of vectors to insert.
        payloads (List[Dict], optional): Payloads, one per vector. When omitted,
            every vector gets an empty payload.
        ids (List[str], optional): IDs, one per vector. When omitted, a random
            UUID4 string is generated for each vector.

    Returns:
        The per-document results returned by ``upload_documents``.

    Raises:
        Exception: If any document is reported as not indexed.
    """
    import uuid  # local import: only needed when the caller supplies no ids

    logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")

    # Both arguments default to None, which zip() cannot iterate.
    if payloads is None:
        payloads = [{} for _ in vectors]
    if ids is None:
        ids = [str(uuid.uuid4()) for _ in vectors]

    documents = [
        self._generate_document(vector, payload, id)
        for id, vector, payload in zip(ids, vectors, payloads)
    ]
    response = self.search_client.upload_documents(documents)
    for doc in response:
        # The SDK reports results as objects exposing `succeeded` / `key`;
        # dict-shaped results are still honoured for compatibility.
        if isinstance(doc, dict):
            succeeded = doc.get("succeeded", doc.get("status", False))
            doc_key = doc.get("key", doc.get("id"))
        else:
            succeeded = getattr(doc, "succeeded", False)
            doc_key = getattr(doc, "key", None)
        if not succeeded:
            raise Exception(f"Insert failed for document {doc_key}: {doc}")
    return response
def _sanitize_key(self, key: str) -> str:
    """Return ``key`` with every non-word character (anything outside ``\\w``) removed."""
    return re.sub(r"[^\w]", "", key)

def _build_filter_expression(self, filters):
    """
    Turn a dict of equality filters into a single filter expression string.

    Each key is sanitized with ``_sanitize_key``. Values are rendered as:
    strings in single quotes with embedded quotes doubled, booleans as
    ``true``/``false``, None as ``null``, and anything else via ``str()``.

    Args:
        filters (dict): Mapping of field name to the value it must equal.

    Returns:
        str: Conditions joined with " and "; empty string for an empty dict.
    """
    filter_conditions = []
    for key, value in filters.items():
        safe_key = self._sanitize_key(key)
        if isinstance(value, str):
            # Double single quotes so the value cannot terminate the literal.
            safe_value = value.replace("'", "''")
            condition = f"{safe_key} eq '{safe_value}'"
        elif isinstance(value, bool):
            # Python would render True/False; the filter syntax uses lowercase.
            condition = f"{safe_key} eq {'true' if value else 'false'}"
        elif value is None:
            condition = f"{safe_key} eq null"
        else:
            condition = f"{safe_key} eq {value}"
        filter_conditions.append(condition)
    # Use 'and' to join multiple conditions.
    return " and ".join(filter_conditions)
def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"):
    """
    Search for similar vectors.

    Args:
        query (List[float]): Query vector.
        limit (int, optional): Number of results to return. Defaults to 5.
        filters (Dict, optional): Filters to apply to the search. Defaults to None.
        vector_filter_mode (str): Forwarded to the search call; determines whether
            filters are applied before or after the vector search.
            Known values: "preFilter" (default) and "postFilter".

    Returns:
        List[OutputData]: Search results.
    """
    # An empty or missing filter dict means "no filter" rather than "".
    filter_expression = self._build_filter_expression(filters) if filters else None

    vector_query = VectorizedQuery(
        vector=query, k_nearest_neighbors=limit, fields="vector"
    )
    # Issue the query exactly once, with the filter mode applied.
    search_results = self.search_client.search(
        vector_queries=[vector_query],
        filter=filter_expression,
        top=limit,
        vector_filter_mode=vector_filter_mode,
    )

    results = []
    for result in search_results:
        # The payload is stored as a JSON string; decode it for the caller.
        payload = json.loads(result["payload"])
        results.append(
            OutputData(
                id=result["id"], score=result["@search.score"], payload=payload
            )
        )
    return results
def delete(self, vector_id):
    """
    Delete a vector by ID.

    Args:
        vector_id (str): ID of the vector to delete.

    Returns:
        The per-document results returned by ``delete_documents``.

    Raises:
        Exception: If the service reports the deletion as unsuccessful.
    """
    # Send the delete request once and keep the result for inspection.
    response = self.search_client.delete_documents(documents=[{"id": vector_id}])
    for doc in response:
        # The SDK reports results as objects exposing `succeeded`;
        # dict-shaped results are still honoured for compatibility.
        if isinstance(doc, dict):
            succeeded = doc.get("succeeded", doc.get("status", False))
        else:
            succeeded = getattr(doc, "succeeded", False)
        if not succeeded:
            raise Exception(f"Delete failed for document {vector_id}: {doc}")
    logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
    return response
def update(self, vector_id, vector=None, payload=None):
"""Update a vector and its payload.
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
@@ -185,10 +258,15 @@ class AzureAISearch(VectorStoreBase):
document["payload"] = json_payload
for field in ["user_id", "run_id", "agent_id"]:
document[field] = payload.get(field)
self.search_client.merge_or_upload_documents(documents=[document])
response = self.search_client.merge_or_upload_documents(documents=[document])
for doc in response:
if not doc.get("status", False):
raise Exception(f"Update failed for document {vector_id}: {doc}")
return response
def get(self, vector_id) -> OutputData:
"""Retrieve a vector by ID.
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
@@ -200,35 +278,43 @@ class AzureAISearch(VectorStoreBase):
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
return OutputData(
id=result["id"], score=None, payload=json.loads(result["payload"])
)
def list_cols(self) -> List[str]:
    """
    List all collections (indexes).

    Returns:
        List[str]: List of index names.
    """
    try:
        names = self.index_client.list_index_names()
    except AttributeError:
        # Fallback for clients without list_index_names(): read the name
        # off each full index definition instead.
        names = (index.name for index in self.index_client.list_indexes())
    # Materialize so callers always get a real list, whatever iterable the
    # client returned.
    return list(names)
def delete_col(self):
    """Remove this store's index (``self.index_name``) via the index client."""
    target_index = self.index_name
    self.index_client.delete_index(target_index)
def col_info(self):
    """
    Get information about the index.

    Returns:
        dict: Index information with two keys: "name" (the index name) and
        "fields" (the index's field definitions as returned by the client).
    """
    index = self.index_client.get_index(self.index_name)
    return {"name": index.name, "fields": index.fields}
def list(self, filters=None, limit=100):
"""List all vectors in the index.
"""
List all vectors in the index.
Args:
filters (Dict, optional): Filters to apply to the list.
filters (dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
@@ -238,13 +324,18 @@ class AzureAISearch(VectorStoreBase):
if filters:
filter_expression = self._build_filter_expression(filters)
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
search_results = self.search_client.search(
search_text="*", filter=filter_expression, top=limit
)
results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return [results]
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
return results
def __del__(self):
"""Close the search client when the object is deleted."""