Support for hybrid search in Azure AI vector store (#2408)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -45,8 +45,10 @@ class AzureAISearch(VectorStoreBase):
|
||||
collection_name,
|
||||
api_key,
|
||||
embedding_model_dims,
|
||||
compression_type: Optional[str] = None,
|
||||
compression_type: Optional[str] = None,
|
||||
use_float16: bool = False,
|
||||
hybrid_search: bool = False,
|
||||
vector_filter_mode: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Azure AI Search vector store.
|
||||
@@ -60,13 +62,17 @@ class AzureAISearch(VectorStoreBase):
|
||||
Allowed values are None (no quantization), "scalar", or "binary".
|
||||
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
|
||||
(Note: This flag is preserved from the initial implementation per feedback.)
|
||||
hybrid_search (bool): Whether to use hybrid search. Default is False.
|
||||
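vector_filter_mode (Optional[str]): Mode for vector filtering (e.g. "preFilter" or "postFilter"). Default is None, in which case no mode is forced here and the Azure AI Search service default applies.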
|
||||
"""
|
||||
self.index_name = collection_name
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
# If compression_type is None, treat it as "none".
|
||||
self.compression_type = (compression_type or "none").lower()
|
||||
self.compression_type = (compression_type or "none").lower()
|
||||
self.use_float16 = use_float16
|
||||
self.hybrid_search = hybrid_search
|
||||
self.vector_filter_mode = vector_filter_mode
|
||||
|
||||
self.search_client = SearchClient(
|
||||
endpoint=f"https://{service_name}.search.windows.net",
|
||||
@@ -113,8 +119,6 @@ class AzureAISearch(VectorStoreBase):
|
||||
)
|
||||
]
|
||||
# If no compression is desired, compression_configurations remains empty.
|
||||
|
||||
|
||||
fields = [
|
||||
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
||||
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
|
||||
@@ -123,11 +127,11 @@ class AzureAISearch(VectorStoreBase):
|
||||
SearchField(
|
||||
name="vector",
|
||||
type=vector_type,
|
||||
searchable=True,
|
||||
searchable=True,
|
||||
vector_search_dimensions=self.embedding_model_dims,
|
||||
vector_search_profile_name="my-vector-config",
|
||||
),
|
||||
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
|
||||
SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
|
||||
]
|
||||
|
||||
vector_search = VectorSearch(
|
||||
@@ -135,7 +139,7 @@ class AzureAISearch(VectorStoreBase):
|
||||
VectorSearchProfile(
|
||||
name="my-vector-config",
|
||||
algorithm_configuration_name="my-algorithms-config",
|
||||
compression_name=compression_name if self.compression_type != "none" else None
|
||||
compression_name=compression_name if self.compression_type != "none" else None,
|
||||
)
|
||||
],
|
||||
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
|
||||
@@ -164,8 +168,7 @@ class AzureAISearch(VectorStoreBase):
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
||||
documents = [
|
||||
self._generate_document(vector, payload, id)
|
||||
for id, vector, payload in zip(ids, vectors, payloads)
|
||||
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
|
||||
]
|
||||
response = self.search_client.upload_documents(documents)
|
||||
for doc in response:
|
||||
@@ -189,12 +192,13 @@ class AzureAISearch(VectorStoreBase):
|
||||
filter_expression = " and ".join(filter_conditions)
|
||||
return filter_expression
|
||||
|
||||
def search(self, query, limit=5, filters=None):
|
||||
def search(self, query, vectors, limit=5, filters=None):
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (List[float]): Query vector.
|
||||
query (str): Query text. Passed as the full-text search string only when hybrid search is enabled.
|
||||
vectors (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
|
||||
@@ -205,23 +209,28 @@ class AzureAISearch(VectorStoreBase):
|
||||
if filters:
|
||||
filter_expression = self._build_filter_expression(filters)
|
||||
|
||||
vector_query = VectorizedQuery(
|
||||
vector=query, k_nearest_neighbors=limit, fields="vector"
|
||||
)
|
||||
search_results = self.search_client.search(
|
||||
vector_queries=[vector_query],
|
||||
filter=filter_expression,
|
||||
top=limit
|
||||
)
|
||||
vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
|
||||
if self.hybrid_search:
|
||||
search_results = self.search_client.search(
|
||||
search_text=query,
|
||||
vector_queries=[vector_query],
|
||||
filter=filter_expression,
|
||||
top=limit,
|
||||
vector_filter_mode=self.vector_filter_mode,
|
||||
search_fields=["payload"],
|
||||
)
|
||||
else:
|
||||
search_results = self.search_client.search(
|
||||
vector_queries=[vector_query],
|
||||
filter=filter_expression,
|
||||
top=limit,
|
||||
vector_filter_mode=self.vector_filter_mode,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(result["payload"])
|
||||
results.append(
|
||||
OutputData(
|
||||
id=result["id"], score=result["@search.score"], payload=payload
|
||||
)
|
||||
)
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
return results
|
||||
|
||||
def delete(self, vector_id):
|
||||
@@ -275,9 +284,7 @@ class AzureAISearch(VectorStoreBase):
|
||||
result = self.search_client.get_document(key=vector_id)
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
return OutputData(
|
||||
id=result["id"], score=None, payload=json.loads(result["payload"])
|
||||
)
|
||||
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
@@ -321,17 +328,11 @@ class AzureAISearch(VectorStoreBase):
|
||||
if filters:
|
||||
filter_expression = self._build_filter_expression(filters)
|
||||
|
||||
search_results = self.search_client.search(
|
||||
search_text="*", filter=filter_expression, top=limit
|
||||
)
|
||||
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(result["payload"])
|
||||
results.append(
|
||||
OutputData(
|
||||
id=result["id"], score=result["@search.score"], payload=payload
|
||||
)
|
||||
)
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
return [results]
|
||||
|
||||
def __del__(self):
|
||||
|
||||
Reference in New Issue
Block a user