Support for hybrid search in Azure AI vector store (#2408)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Dev Khant
2025-03-20 22:57:00 +05:30
committed by GitHub
parent 8b9a8e5825
commit 8e6a08aa83
24 changed files with 275 additions and 294 deletions

View File

@@ -45,8 +45,10 @@ class AzureAISearch(VectorStoreBase):
collection_name,
api_key,
embedding_model_dims,
compression_type: Optional[str] = None,
compression_type: Optional[str] = None,
use_float16: bool = False,
hybrid_search: bool = False,
vector_filter_mode: Optional[str] = None,
):
"""
Initialize the Azure AI Search vector store.
@@ -60,13 +62,17 @@ class AzureAISearch(VectorStoreBase):
Allowed values are None (no quantization), "scalar", or "binary".
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
(Note: This flag is preserved from the initial implementation per feedback.)
hybrid_search (bool): Whether to use hybrid search. Default is False.
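vector_filter_mode (Optional[str]): Mode for vector filtering (e.g. "preFilter" or "postFilter"). Defaults to None, in which case no mode is chosen here and the value is passed to the search client as None.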
"""
self.index_name = collection_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
# If compression_type is None, treat it as "none".
self.compression_type = (compression_type or "none").lower()
self.compression_type = (compression_type or "none").lower()
self.use_float16 = use_float16
self.hybrid_search = hybrid_search
self.vector_filter_mode = vector_filter_mode
self.search_client = SearchClient(
endpoint=f"https://{service_name}.search.windows.net",
@@ -113,8 +119,6 @@ class AzureAISearch(VectorStoreBase):
)
]
# If no compression is desired, compression_configurations remains empty.
fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
@@ -123,11 +127,11 @@ class AzureAISearch(VectorStoreBase):
SearchField(
name="vector",
type=vector_type,
searchable=True,
searchable=True,
vector_search_dimensions=self.embedding_model_dims,
vector_search_profile_name="my-vector-config",
),
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
]
vector_search = VectorSearch(
@@ -135,7 +139,7 @@ class AzureAISearch(VectorStoreBase):
VectorSearchProfile(
name="my-vector-config",
algorithm_configuration_name="my-algorithms-config",
compression_name=compression_name if self.compression_type != "none" else None
compression_name=compression_name if self.compression_type != "none" else None,
)
],
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
@@ -164,8 +168,7 @@ class AzureAISearch(VectorStoreBase):
"""
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
documents = [
self._generate_document(vector, payload, id)
for id, vector, payload in zip(ids, vectors, payloads)
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
]
response = self.search_client.upload_documents(documents)
for doc in response:
@@ -189,12 +192,13 @@ class AzureAISearch(VectorStoreBase):
filter_expression = " and ".join(filter_conditions)
return filter_expression
def search(self, query, limit=5, filters=None):
def search(self, query, vectors, limit=5, filters=None):
"""
Search for similar vectors.
Args:
query (List[float]): Query vector.
query (str): Query text. Sent as the full-text `search_text` only when hybrid search is enabled; otherwise unused.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
@@ -205,23 +209,28 @@ class AzureAISearch(VectorStoreBase):
if filters:
filter_expression = self._build_filter_expression(filters)
vector_query = VectorizedQuery(
vector=query, k_nearest_neighbors=limit, fields="vector"
)
search_results = self.search_client.search(
vector_queries=[vector_query],
filter=filter_expression,
top=limit
)
vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
if self.hybrid_search:
search_results = self.search_client.search(
search_text=query,
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=self.vector_filter_mode,
search_fields=["payload"],
)
else:
search_results = self.search_client.search(
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=self.vector_filter_mode,
)
results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return results
def delete(self, vector_id):
@@ -275,9 +284,7 @@ class AzureAISearch(VectorStoreBase):
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
return OutputData(
id=result["id"], score=None, payload=json.loads(result["payload"])
)
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
def list_cols(self) -> List[str]:
"""
@@ -321,17 +328,11 @@ class AzureAISearch(VectorStoreBase):
if filters:
filter_expression = self._build_filter_expression(filters)
search_results = self.search_client.search(
search_text="*", filter=filter_expression, top=limit
)
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return [results]
def __del__(self):