Fix query filter in azure ai search (#2171)
This commit is contained in:
@@ -76,6 +76,9 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
||||||
|
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
|
||||||
|
SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
|
||||||
|
SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
|
||||||
SearchField(
|
SearchField(
|
||||||
name="vector",
|
name="vector",
|
||||||
type=vector_type,
|
type=vector_type,
|
||||||
@@ -96,6 +99,14 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
|
index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
|
||||||
self.index_client.create_or_update_index(index)
|
self.index_client.create_or_update_index(index)
|
||||||
|
|
||||||
|
def _generate_document(self, vector, payload, id):
|
||||||
|
document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
|
||||||
|
# Extract additional fields if they exist
|
||||||
|
for field in ["user_id", "run_id", "agent_id"]:
|
||||||
|
if field in payload:
|
||||||
|
document[field] = payload[field]
|
||||||
|
return document
|
||||||
|
|
||||||
def insert(self, vectors, payloads=None, ids=None):
|
def insert(self, vectors, payloads=None, ids=None):
|
||||||
"""Insert vectors into the index.
|
"""Insert vectors into the index.
|
||||||
|
|
||||||
@@ -105,12 +116,26 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
{"id": id, "vector": vector, "payload": json.dumps(payload)}
|
self._generate_document(vector, payload, id)
|
||||||
for id, vector, payload in zip(ids, vectors, payloads)
|
for id, vector, payload in zip(ids, vectors, payloads)
|
||||||
]
|
]
|
||||||
self.search_client.upload_documents(documents)
|
self.search_client.upload_documents(documents)
|
||||||
|
|
||||||
|
def _build_filter_expression(self, filters):
|
||||||
|
filter_conditions = []
|
||||||
|
for key, value in filters.items():
|
||||||
|
# If the value is a string, add quotes
|
||||||
|
if isinstance(value, str):
|
||||||
|
condition = f"{key} eq '{value}'"
|
||||||
|
else:
|
||||||
|
condition = f"{key} eq {value}"
|
||||||
|
filter_conditions.append(condition)
|
||||||
|
# Use 'and' to join multiple conditions
|
||||||
|
filter_expression = ' and '.join(filter_conditions)
|
||||||
|
return filter_expression
|
||||||
|
|
||||||
def search(self, query, limit=5, filters=None):
|
def search(self, query, limit=5, filters=None):
|
||||||
"""Search for similar vectors.
|
"""Search for similar vectors.
|
||||||
|
|
||||||
@@ -122,17 +147,23 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
Returns:
|
Returns:
|
||||||
list: Search results.
|
list: Search results.
|
||||||
"""
|
"""
|
||||||
|
# Build filter expression
|
||||||
|
filter_expression = None
|
||||||
|
if filters:
|
||||||
|
filter_expression = self._build_filter_expression(filters)
|
||||||
|
|
||||||
vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
|
vector_query = VectorizedQuery(
|
||||||
search_results = self.search_client.search(vector_queries=[vector_query], top=limit)
|
vector=query, k_nearest_neighbors=limit, fields="vector"
|
||||||
|
)
|
||||||
|
search_results = self.search_client.search(
|
||||||
|
vector_queries=[vector_query],
|
||||||
|
filter=filter_expression,
|
||||||
|
top=limit
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
payload = json.loads(result["payload"])
|
payload = json.loads(result["payload"])
|
||||||
if filters:
|
|
||||||
for key, value in filters.items():
|
|
||||||
if key not in payload or payload[key] != value:
|
|
||||||
continue
|
|
||||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -143,6 +174,7 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
vector_id (str): ID of the vector to delete.
|
vector_id (str): ID of the vector to delete.
|
||||||
"""
|
"""
|
||||||
self.search_client.delete_documents(documents=[{"id": vector_id}])
|
self.search_client.delete_documents(documents=[{"id": vector_id}])
|
||||||
|
logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
|
||||||
|
|
||||||
def update(self, vector_id, vector=None, payload=None):
|
def update(self, vector_id, vector=None, payload=None):
|
||||||
"""Update a vector and its payload.
|
"""Update a vector and its payload.
|
||||||
@@ -156,7 +188,10 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
if vector:
|
if vector:
|
||||||
document["vector"] = vector
|
document["vector"] = vector
|
||||||
if payload:
|
if payload:
|
||||||
document["payload"] = json.dumps(payload)
|
json_payload = json.dumps(payload)
|
||||||
|
document["payload"] = json_payload
|
||||||
|
for field in ["user_id", "run_id", "agent_id"]:
|
||||||
|
document[field] = payload.get(field)
|
||||||
self.search_client.merge_or_upload_documents(documents=[document])
|
self.search_client.merge_or_upload_documents(documents=[document])
|
||||||
|
|
||||||
def get(self, vector_id) -> OutputData:
|
def get(self, vector_id) -> OutputData:
|
||||||
@@ -206,17 +241,18 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
Returns:
|
Returns:
|
||||||
List[OutputData]: List of vectors.
|
List[OutputData]: List of vectors.
|
||||||
"""
|
"""
|
||||||
search_results = self.search_client.search(search_text="*", top=limit)
|
filter_expression = None
|
||||||
|
if filters:
|
||||||
|
filter_expression = self._build_filter_expression(filters)
|
||||||
|
|
||||||
|
search_results = self.search_client.search(
|
||||||
|
search_text="*",
|
||||||
|
filter=filter_expression,
|
||||||
|
top=limit
|
||||||
|
)
|
||||||
results = []
|
results = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
payload = json.loads(result["payload"])
|
payload = json.loads(result["payload"])
|
||||||
include_result = True
|
|
||||||
if filters:
|
|
||||||
for key, value in filters.items():
|
|
||||||
if (key not in payload) or (payload[key] != filters[key]):
|
|
||||||
include_result = False
|
|
||||||
break
|
|
||||||
if include_result:
|
|
||||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||||
|
|
||||||
return [results]
|
return [results]
|
||||||
|
|||||||
Reference in New Issue
Block a user