From 8d3c8c695d6d52037c1c3d306d89a409b483cced Mon Sep 17 00:00:00 2001
From: junmo1215 <997128120@qq.com>
Date: Fri, 31 Jan 2025 18:08:06 +0800
Subject: [PATCH] Fix query filter in azure ai search (#2171)

---
Reviewer notes (not part of the commit message):
* Line breaks and leading whitespace were restored; this copy of the patch
  had been flattened, so indentation is reconstructed from the code structure.
* _build_filter_expression now doubles single quotes inside string values and
  emits null/true/false for None/bool. Both new helpers gained docstrings.
* Hunk offsets and the diffstat were recomputed for the added lines, so the
  post-image hash on the index line below is stale. The last hunk carries two
  trailing context lines.

 mem0/vector_stores/azure_ai_search.py | 97 ++++++++++++++++++++++-----
 1 file changed, 80 insertions(+), 17 deletions(-)

diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py
index 1c68c4f7..c7d5cb2d 100644
--- a/mem0/vector_stores/azure_ai_search.py
+++ b/mem0/vector_stores/azure_ai_search.py
@@ -76,6 +76,9 @@ class AzureAISearch(VectorStoreBase):
 
         fields = [
             SimpleField(name="id", type=SearchFieldDataType.String, key=True),
+            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
+            SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
+            SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
             SearchField(
                 name="vector",
                 type=vector_type,
@@ -96,6 +99,26 @@ class AzureAISearch(VectorStoreBase):
         index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
         self.index_client.create_or_update_index(index)
 
+    def _generate_document(self, vector, payload, id):
+        """Build the document uploaded to the index for one vector.
+
+        Args:
+            vector: Value stored in the "vector" field.
+            payload (dict): Metadata, stored JSON-encoded in the "payload" field.
+            id: Value stored in the key field "id".
+
+        Returns:
+            dict: Document with "id", "vector", "payload" and any of
+                "user_id" / "run_id" / "agent_id" present in payload.
+        """
+        document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
+        # Copy the ids out of the payload into their own top-level fields,
+        # which the index declares filterable.
+        for field in ["user_id", "run_id", "agent_id"]:
+            if field in payload:
+                document[field] = payload[field]
+        return document
+
     def insert(self, vectors, payloads=None, ids=None):
         """Insert vectors into the index.
 
@@ -105,12 +128,41 @@ class AzureAISearch(VectorStoreBase):
             ids (List[str], optional): List of IDs corresponding to vectors.
         """
         logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
+
         documents = [
-            {"id": id, "vector": vector, "payload": json.dumps(payload)}
+            self._generate_document(vector, payload, id)
             for id, vector, payload in zip(ids, vectors, payloads)
         ]
         self.search_client.upload_documents(documents)
 
+    def _build_filter_expression(self, filters):
+        """Build an OData filter that ANDs one equality test per entry.
+
+        Args:
+            filters (dict): Field name mapped to the value it must equal.
+
+        Returns:
+            str: Conditions joined with ' and '.
+        """
+        filter_conditions = []
+        for key, value in filters.items():
+            if isinstance(value, str):
+                # OData string literals are single-quoted and escape an
+                # embedded quote by doubling it; without this a value such
+                # as "o'brien" would break out of the literal.
+                escaped = value.replace("'", "''")
+                condition = f"{key} eq '{escaped}'"
+            elif value is None or isinstance(value, bool):
+                # OData spells these null / true / false, which is what
+                # json.dumps emits; str() would give None / True / False.
+                condition = f"{key} eq {json.dumps(value)}"
+            else:
+                condition = f"{key} eq {value}"
+            filter_conditions.append(condition)
+        # Use 'and' to join multiple conditions
+        filter_expression = ' and '.join(filter_conditions)
+        return filter_expression
+
     def search(self, query, limit=5, filters=None):
         """Search for similar vectors.
 
@@ -122,17 +174,23 @@ class AzureAISearch(VectorStoreBase):
         Returns:
             list: Search results.
         """
+        # Build filter expression
+        filter_expression = None
+        if filters:
+            filter_expression = self._build_filter_expression(filters)
 
-        vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
-        search_results = self.search_client.search(vector_queries=[vector_query], top=limit)
+        vector_query = VectorizedQuery(
+            vector=query, k_nearest_neighbors=limit, fields="vector"
+        )
+        search_results = self.search_client.search(
+            vector_queries=[vector_query],
+            filter=filter_expression,
+            top=limit
+        )
 
         results = []
         for result in search_results:
             payload = json.loads(result["payload"])
-            if filters:
-                for key, value in filters.items():
-                    if key not in payload or payload[key] != value:
-                        continue
             results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
 
         return results
@@ -143,6 +201,7 @@ class AzureAISearch(VectorStoreBase):
             vector_id (str): ID of the vector to delete.
         """
         self.search_client.delete_documents(documents=[{"id": vector_id}])
+        logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
 
     def update(self, vector_id, vector=None, payload=None):
         """Update a vector and its payload.
@@ -156,7 +215,10 @@ class AzureAISearch(VectorStoreBase):
         if vector:
             document["vector"] = vector
         if payload:
-            document["payload"] = json.dumps(payload)
+            json_payload = json.dumps(payload)
+            document["payload"] = json_payload
+            for field in ["user_id", "run_id", "agent_id"]:
+                document[field] = payload.get(field)
         self.search_client.merge_or_upload_documents(documents=[document])
 
     def get(self, vector_id) -> OutputData:
@@ -206,17 +268,18 @@ class AzureAISearch(VectorStoreBase):
         Returns:
             List[OutputData]: List of vectors.
         """
-        search_results = self.search_client.search(search_text="*", top=limit)
+        filter_expression = None
+        if filters:
+            filter_expression = self._build_filter_expression(filters)
+
+        search_results = self.search_client.search(
+            search_text="*",
+            filter=filter_expression,
+            top=limit
+        )
         results = []
         for result in search_results:
             payload = json.loads(result["payload"])
-            include_result = True
-            if filters:
-                for key, value in filters.items():
-                    if (key not in payload) or (payload[key] != filters[key]):
-                        include_result = False
-                        break
-            if include_result:
-                results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
+            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
 
         return [results]