feat: enhance Azure AI Search Integration with Binary Quantization, Pre/Post Filter Options, and user agent header (#2354)
Docs: Azure AI Search vector store
@@ -1,12 +1,14 @@
-[Azure AI Search](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.
+# Azure AI Search
 
-### Usage
+[Azure AI Search](https://learn.microsoft.com/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.
 
+## Usage
+
 ```python
 import os
 from mem0 import Memory
 
-os.environ["OPENAI_API_KEY"] = "sk-xx" #this key is used for embedding purpose
+os.environ["OPENAI_API_KEY"] = "sk-xx" # This key is used for embedding purpose
 
 config = {
     "vector_store": {
@@ -15,8 +17,8 @@ config = {
         "service_name": "ai-search-test",
         "api_key": "*****",
         "collection_name": "mem0",
-        "embedding_model_dims": 1536 ,
+        "embedding_model_dims": 1536,
-        "use_compression": False
+        "compression_type": "none"
         }
     }
 }
@@ -25,20 +27,61 @@ m = Memory.from_config(config)
 messages = [
     {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
     {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
-    {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
     {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
 ]
 m.add(messages, user_id="alice", metadata={"category": "movies"})
 ```
 
-### Config
+## Advanced Usage
 
-Let's see the available parameters for the `qdrant` config:
-service_name (str): Azure Cognitive Search service name.
-| Parameter | Description | Default Value |
-| --- | --- | --- |
-| `service_name` | Azure AI Search service name | `None` |
-| `api_key` | API key of the Azure AI Search service | `None` |
-| `collection_name` | The name of the collection/index to store the vectors, it will be created automatically if not exist | `mem0` |
-| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
-| `use_compression` | Use scalar quantization vector compression | False |
+```python
+# Search with specific filter mode
+result = m.search(
+    "sci-fi movies",
+    filters={"user_id": "alice"},
+    limit=5,
+    vector_filter_mode="preFilter"  # Apply filters before vector search
+)
+
+# Using binary compression for large vector collections
+config = {
+    "vector_store": {
+        "provider": "azure_ai_search",
+        "config": {
+            "service_name": "ai-search-test",
+            "api_key": "*****",
+            "collection_name": "mem0",
+            "embedding_model_dims": 1536,
+            "compression_type": "binary",
+            "use_float16": True  # Use half precision for storage efficiency
+        }
+    }
+}
+```
+
+## Configuration Parameters
+
+| Parameter | Description | Default Value | Options |
+| --- | --- | --- | --- |
+| `service_name` | Azure AI Search service name | Required | - |
+| `api_key` | API key of the Azure AI Search service | Required | - |
+| `collection_name` | The name of the collection/index to store vectors | `mem0` | Any valid index name |
+| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
+| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
+| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
+
+## Notes on Configuration Options
+
+- **compression_type**:
+  - `none`: No compression, uses full vector precision
+  - `scalar`: Scalar quantization with reasonable balance of speed and accuracy
+  - `binary`: Binary quantization for maximum compression with some accuracy trade-off
+
+- **vector_filter_mode**:
+  - `preFilter`: Applies filters before vector search (faster)
+  - `postFilter`: Applies filters after vector search (may provide better relevance)
+
+- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.
+
+- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
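The docs above demonstrate only a single-field filter. As a small sketch combining the documented pieces, here is a multi-field search in `postFilter` mode; the `agent_id` value and the reuse of `config` from the Usage section are illustrative assumptions, and the commented OData string follows this PR's `_build_filter_expression`, which joins conditions with `and`.

```python
from mem0 import Memory

m = Memory.from_config(config)  # the azure_ai_search config shown above

result = m.search(
    "sci-fi movies",
    filters={"user_id": "alice", "agent_id": "recommender"},  # illustrative values
    limit=5,
    vector_filter_mode="postFilter",  # apply filters after the vector search
)
# Per this PR, the filters above become the OData expression:
#   user_id eq 'alice' and agent_id eq 'recommender'
```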
mem0/vector_stores/azure_ai_search.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from typing import List, Optional
 
 from pydantic import BaseModel
@@ -12,6 +13,7 @@ try:
     from azure.search.documents import SearchClient
     from azure.search.documents.indexes import SearchIndexClient
     from azure.search.documents.indexes.models import (
+        BinaryQuantizationCompression,
         HnswAlgorithmConfiguration,
         ScalarQuantizationCompression,
         SearchField,
@@ -24,7 +26,7 @@ try:
     from azure.search.documents.models import VectorizedQuery
 except ImportError:
     raise ImportError(
-        "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'."
+        "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
     )
 
 logger = logging.getLogger(__name__)
@@ -37,42 +39,81 @@ class OutputData(BaseModel):
 
 
 class AzureAISearch(VectorStoreBase):
-    def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression):
-        """Initialize the Azure Cognitive Search vector store.
+    def __init__(
+        self,
+        service_name,
+        collection_name,
+        api_key,
+        embedding_model_dims,
+        compression_type: Optional[str] = None,
+        use_float16: bool = False,
+    ):
+        """
+        Initialize the Azure AI Search vector store.
 
         Args:
-            service_name (str): Azure Cognitive Search service name.
+            service_name (str): Azure AI Search service name.
             collection_name (str): Index name.
-            api_key (str): API key for the Azure Cognitive Search service.
+            api_key (str): API key for the Azure AI Search service.
             embedding_model_dims (int): Dimension of the embedding vector.
-            use_compression (bool): Use scalar quantization vector compression
+            compression_type (Optional[str]): Specifies the type of quantization to use.
+                Allowed values are None (no quantization), "scalar", or "binary".
+            use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
+                (Note: This flag is preserved from the initial implementation per feedback.)
         """
         self.index_name = collection_name
         self.collection_name = collection_name
         self.embedding_model_dims = embedding_model_dims
-        self.use_compression = use_compression
+        # If compression_type is None, treat it as "none".
+        self.compression_type = (compression_type or "none").lower()
+        self.use_float16 = use_float16
 
         self.search_client = SearchClient(
             endpoint=f"https://{service_name}.search.windows.net",
             index_name=self.index_name,
             credential=AzureKeyCredential(api_key),
         )
         self.index_client = SearchIndexClient(
-            endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key)
+            endpoint=f"https://{service_name}.search.windows.net",
+            credential=AzureKeyCredential(api_key),
         )
+
+        self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
+        self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
+
         self.create_col()  # create the collection / index
 
     def create_col(self):
-        """Create a new index in Azure Cognitive Search."""
-        vector_dimensions = self.embedding_model_dims  # Set this to the number of dimensions in your vector
-        if self.use_compression:
+        """Create a new index in Azure AI Search."""
+        # Determine vector type based on use_float16 setting.
+        if self.use_float16:
             vector_type = "Collection(Edm.Half)"
-            compression_name = "myCompression"
-            compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
         else:
             vector_type = "Collection(Edm.Single)"
-            compression_name = None
-            compression_configurations = []
+
+        # Configure compression settings based on the specified compression_type.
+        compression_configurations = []
+        compression_name = None
+        if self.compression_type == "scalar":
+            compression_name = "myCompression"
+            # For SQ, rescoring defaults to True and oversampling defaults to 4.
+            compression_configurations = [
+                ScalarQuantizationCompression(
+                    compression_name=compression_name
+                    # rescoring defaults to True and oversampling defaults to 4
+                )
+            ]
+        elif self.compression_type == "binary":
+            compression_name = "myCompression"
+            # For BQ, rescoring defaults to True and oversampling defaults to 10.
+            compression_configurations = [
+                BinaryQuantizationCompression(
+                    compression_name=compression_name
+                    # rescoring defaults to True and oversampling defaults to 10
                )
+            ]
+        # If no compression is desired, compression_configurations remains empty.
 
         fields = [
             SimpleField(name="id", type=SearchFieldDataType.String, key=True),
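A side note on the user-agent lines in the hunk above: they reach into the private `_client._config` to append "mem0". As a hedged alternative sketch (not what this PR does), azure-core clients generally accept a `user_agent` keyword that the pipeline's `UserAgentPolicy` appends to the default header; the endpoint and key below are placeholders.

```python
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient

# Sketch only, assuming the standard azure-core `user_agent` keyword.
client = SearchClient(
    endpoint="https://ai-search-test.search.windows.net",  # placeholder service
    index_name="mem0",
    credential=AzureKeyCredential("*****"),  # placeholder key
    user_agent="mem0",
)
```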
@@ -83,7 +124,7 @@ class AzureAISearch(VectorStoreBase):
                 name="vector",
                 type=vector_type,
                 searchable=True,
-                vector_search_dimensions=vector_dimensions,
+                vector_search_dimensions=self.embedding_model_dims,
                 vector_search_profile_name="my-vector-config",
             ),
             SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
@@ -91,7 +132,11 @@ class AzureAISearch(VectorStoreBase):
 
         vector_search = VectorSearch(
             profiles=[
-                VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
+                VectorSearchProfile(
+                    name="my-vector-config",
+                    algorithm_configuration_name="my-algorithms-config",
+                    compression_name=compression_name if self.compression_type != "none" else None
+                )
             ],
             algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
             compressions=compression_configurations,
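Read together, the hunks above connect compression to the index in three places: the compression definition, the profile that names it, and the vector field whose storage type `use_float16` selects. A condensed sketch of those pieces for `compression_type="binary"` with `use_float16=True`, using only model classes this PR imports (the dimension value is illustrative):

```python
from azure.search.documents.indexes.models import (
    BinaryQuantizationCompression,
    HnswAlgorithmConfiguration,
    SearchField,
    VectorSearch,
    VectorSearchProfile,
)

dims = 1536  # illustrative embedding_model_dims

vector_search = VectorSearch(
    profiles=[
        VectorSearchProfile(
            name="my-vector-config",
            algorithm_configuration_name="my-algorithms-config",
            compression_name="myCompression",  # ties the profile to the compression below
        )
    ],
    algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
    compressions=[BinaryQuantizationCompression(compression_name="myCompression")],
)
vector_field = SearchField(
    name="vector",
    type="Collection(Edm.Half)",  # use_float16=True; otherwise "Collection(Edm.Single)"
    searchable=True,
    vector_search_dimensions=dims,
    vector_search_profile_name="my-vector-config",
)
```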
@@ -101,14 +146,16 @@ class AzureAISearch(VectorStoreBase):
 
     def _generate_document(self, vector, payload, id):
         document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
-        # Extract additional fields if they exist
+        # Extract additional fields if they exist.
         for field in ["user_id", "run_id", "agent_id"]:
             if field in payload:
                 document[field] = payload[field]
         return document
 
+    # Note: Explicit "insert" calls may later be decoupled from memory management decisions.
     def insert(self, vectors, payloads=None, ids=None):
-        """Insert vectors into the index.
+        """
+        Insert vectors into the index.
+
         Args:
             vectors (List[List[float]]): List of vectors to insert.
@@ -116,61 +163,87 @@ class AzureAISearch(VectorStoreBase):
             ids (List[str], optional): List of IDs corresponding to vectors.
         """
         logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
 
         documents = [
-            self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
+            self._generate_document(vector, payload, id)
+            for id, vector, payload in zip(ids, vectors, payloads)
         ]
-        self.search_client.upload_documents(documents)
+        response = self.search_client.upload_documents(documents)
+        for doc in response:
+            if not doc.get("status", False):
+                raise Exception(f"Insert failed for document {doc.get('id')}: {doc}")
+        return response
+
+    def _sanitize_key(self, key: str) -> str:
+        return re.sub(r"[^\w]", "", key)
 
     def _build_filter_expression(self, filters):
         filter_conditions = []
         for key, value in filters.items():
-            # If the value is a string, add quotes
+            safe_key = self._sanitize_key(key)
             if isinstance(value, str):
-                condition = f"{key} eq '{value}'"
+                safe_value = value.replace("'", "''")
+                condition = f"{safe_key} eq '{safe_value}'"
             else:
-                condition = f"{key} eq {value}"
+                condition = f"{safe_key} eq {value}"
             filter_conditions.append(condition)
-        # Use 'and' to join multiple conditions
+
         filter_expression = " and ".join(filter_conditions)
         return filter_expression
 
-    def search(self, query, limit=5, filters=None):
-        """Search for similar vectors.
+    def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"):
+        """
+        Search for similar vectors.
+
         Args:
-            query (List[float]): Query vectors.
+            query (List[float]): Query vector.
             limit (int, optional): Number of results to return. Defaults to 5.
             filters (Dict, optional): Filters to apply to the search. Defaults to None.
+            vector_filter_mode (str): Determines whether filters are applied before or after the vector search.
+                Known values: "preFilter" (default) and "postFilter".
 
         Returns:
-            list: Search results.
+            List[OutputData]: Search results.
         """
-        # Build filter expression
         filter_expression = None
         if filters:
             filter_expression = self._build_filter_expression(filters)
 
-        vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
-        search_results = self.search_client.search(vector_queries=[vector_query], filter=filter_expression, top=limit)
+        vector_query = VectorizedQuery(
+            vector=query, k_nearest_neighbors=limit, fields="vector"
+        )
+        search_results = self.search_client.search(
+            vector_queries=[vector_query],
+            filter=filter_expression,
+            top=limit,
+            vector_filter_mode=vector_filter_mode,
+        )
 
         results = []
         for result in search_results:
             payload = json.loads(result["payload"])
-            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
+            results.append(
+                OutputData(
+                    id=result["id"], score=result["@search.score"], payload=payload
+                )
+            )
         return results
 
     def delete(self, vector_id):
-        """Delete a vector by ID.
+        """
+        Delete a vector by ID.
+
         Args:
             vector_id (str): ID of the vector to delete.
         """
-        self.search_client.delete_documents(documents=[{"id": vector_id}])
+        response = self.search_client.delete_documents(documents=[{"id": vector_id}])
+        for doc in response:
+            if not doc.get("status", False):
+                raise Exception(f"Delete failed for document {vector_id}: {doc}")
         logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
+        return response
 
     def update(self, vector_id, vector=None, payload=None):
-        """Update a vector and its payload.
+        """
+        Update a vector and its payload.
+
         Args:
             vector_id (str): ID of the vector to update.
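To make the new escaping behavior concrete, here is a standalone sketch that mirrors `_sanitize_key` and the quote doubling from the hunk above (the input values are made up):

```python
import re


def build_filter_expression(filters):
    # Mirrors the PR: strip non-word characters from keys, double single
    # quotes in string values (OData escaping), join conditions with "and".
    conditions = []
    for key, value in filters.items():
        safe_key = re.sub(r"[^\w]", "", key)
        if isinstance(value, str):
            safe_value = value.replace("'", "''")
            conditions.append(f"{safe_key} eq '{safe_value}'")
        else:
            conditions.append(f"{safe_key} eq {value}")
    return " and ".join(conditions)


print(build_filter_expression({"user_id": "user's-data", "count": 42}))
# -> user_id eq 'user''s-data' and count eq 42
```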
@@ -185,10 +258,15 @@ class AzureAISearch(VectorStoreBase):
             document["payload"] = json_payload
             for field in ["user_id", "run_id", "agent_id"]:
                 document[field] = payload.get(field)
-        self.search_client.merge_or_upload_documents(documents=[document])
+        response = self.search_client.merge_or_upload_documents(documents=[document])
+        for doc in response:
+            if not doc.get("status", False):
+                raise Exception(f"Update failed for document {vector_id}: {doc}")
+        return response
 
     def get(self, vector_id) -> OutputData:
-        """Retrieve a vector by ID.
+        """
+        Retrieve a vector by ID.
+
         Args:
             vector_id (str): ID of the vector to retrieve.
@@ -200,35 +278,43 @@ class AzureAISearch(VectorStoreBase):
             result = self.search_client.get_document(key=vector_id)
         except ResourceNotFoundError:
             return None
-        return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
+        return OutputData(
+            id=result["id"], score=None, payload=json.loads(result["payload"])
+        )
 
     def list_cols(self) -> List[str]:
-        """List all collections (indexes).
+        """
+        List all collections (indexes).
+
         Returns:
             List[str]: List of index names.
         """
-        indexes = self.index_client.list_indexes()
-        return [index.name for index in indexes]
+        try:
+            names = self.index_client.list_index_names()
+        except AttributeError:
+            names = [index.name for index in self.index_client.list_indexes()]
+        return names
 
     def delete_col(self):
         """Delete the index."""
         self.index_client.delete_index(self.index_name)
 
     def col_info(self):
-        """Get information about the index.
+        """
+        Get information about the index.
+
         Returns:
-            Dict[str, Any]: Index information.
+            dict: Index information.
         """
         index = self.index_client.get_index(self.index_name)
         return {"name": index.name, "fields": index.fields}
 
     def list(self, filters=None, limit=100):
-        """List all vectors in the index.
+        """
+        List all vectors in the index.
+
         Args:
-            filters (Dict, optional): Filters to apply to the list.
+            filters (dict, optional): Filters to apply to the list.
             limit (int, optional): Number of vectors to return. Defaults to 100.
 
         Returns:
@@ -238,13 +324,18 @@ class AzureAISearch(VectorStoreBase):
         if filters:
             filter_expression = self._build_filter_expression(filters)
 
-        search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
+        search_results = self.search_client.search(
+            search_text="*", filter=filter_expression, top=limit
+        )
         results = []
         for result in search_results:
             payload = json.loads(result["payload"])
-            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
-        return [results]
+            results.append(
+                OutputData(
+                    id=result["id"], score=result["@search.score"], payload=payload
+                )
+            )
+        return results
 
     def __del__(self):
         """Close the search client when the object is deleted."""
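Before the generated lockfile changes below, a short sketch of the class surface this file now exposes, driven directly rather than through `Memory`. It assumes a reachable service; the service name and key are placeholders.

```python
from mem0.vector_stores.azure_ai_search import AzureAISearch

store = AzureAISearch(
    service_name="ai-search-test",  # placeholder
    collection_name="mem0",
    api_key="*****",                # placeholder
    embedding_model_dims=1536,
    compression_type="scalar",
    use_float16=False,
)
store.insert([[0.0] * 1536], payloads=[{"user_id": "alice"}], ids=["doc1"])
hits = store.search(
    [0.0] * 1536,
    limit=5,
    filters={"user_id": "alice"},
    vector_filter_mode="postFilter",
)
store.delete("doc1")
```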
poetry.lock (generated, 90 lines changed)
@@ -112,7 +112,7 @@ propcache = ">=0.2.0"
 yarl = ">=1.17.0,<2.0"
 
 [package.extras]
-speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
+speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
 
 [[package]]
 name = "aiosignal"
@@ -158,7 +158,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""}
 
 [package.extras]
 doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"]
-test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"]
+test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""]
 trio = ["trio (>=0.26.1)"]
 
 [[package]]
@@ -184,12 +184,62 @@ files = [
 ]
 
 [package.extras]
-benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
-cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
-dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
 docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
-tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
-tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
+tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""]
 
+[[package]]
+name = "azure-common"
+version = "1.1.28"
+description = "Microsoft Azure Client Library for Python (Common)"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+    {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"},
+    {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"},
+]
+
+[[package]]
+name = "azure-core"
+version = "1.32.0"
+description = "Microsoft Azure Core Library for Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "azure_core-1.32.0-py3-none-any.whl", hash = "sha256:eac191a0efb23bfa83fddf321b27b122b4ec847befa3091fa736a5c32c50d7b4"},
+    {file = "azure_core-1.32.0.tar.gz", hash = "sha256:22b3c35d6b2dae14990f6c1be2912bf23ffe50b220e708a28ab1bb92b1c730e5"},
+]
+
+[package.dependencies]
+requests = ">=2.21.0"
+six = ">=1.11.0"
+typing-extensions = ">=4.6.0"
+
+[package.extras]
+aio = ["aiohttp (>=3.0)"]
+
+[[package]]
+name = "azure-search-documents"
+version = "11.5.2"
+description = "Microsoft Azure Cognitive Search Client Library for Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "azure_search_documents-11.5.2-py3-none-any.whl", hash = "sha256:c949d011008a4b0bcee3db91132741b4e4d50ddb3f7e2f48944d949d4b413b11"},
+    {file = "azure_search_documents-11.5.2.tar.gz", hash = "sha256:98977dd1fa4978d3b7d8891a0856b3becb6f02cc07ff2e1ea40b9c7254ada315"},
+]
+
+[package.dependencies]
+azure-common = ">=1.1"
+azure-core = ">=1.28.0"
+isodate = ">=0.6.0"
+typing-extensions = ">=4.6.0"
+
 [[package]]
 name = "backoff"
@@ -868,7 +918,7 @@ httpcore = "==1.*"
 idna = "*"
 
 [package.extras]
-brotli = ["brotli", "brotlicffi"]
+brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
 cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
 http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
@@ -910,6 +960,18 @@ files = [
     {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
 ]
 
+[[package]]
+name = "isodate"
+version = "0.7.2"
+description = "An ISO 8601 date/time/duration parser and formatter"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+    {file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"},
+    {file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"},
+]
+
 [[package]]
 name = "isort"
 version = "5.13.2"
@@ -1799,7 +1861,7 @@ typing-extensions = ">=4.12.2"
 
 [package.extras]
 email = ["email-validator (>=2.0.0)"]
-timezone = ["tzdata"]
+timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
 
 [[package]]
 name = "pydantic-core"
@@ -2213,13 +2275,13 @@ files = [
 ]
 
 [package.extras]
-check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
-core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
+core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
 cover = ["pytest-cov"]
 doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
 enabler = ["pytest-enabler (>=2.2)"]
-test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
-type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
+type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
 
 [[package]]
 name = "six"
@@ -2449,7 +2511,7 @@ files = [
 ]
 
 [package.extras]
-brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""]
 h2 = ["h2 (>=4,<5)"]
 socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
 zstd = ["zstandard (>=0.18.0)"]
pyproject.toml
@@ -25,6 +25,7 @@ sqlalchemy = "^2.0.31"
 langchain-neo4j = "^0.4.0"
 neo4j = "^5.23.1"
 rank-bm25 = "^0.2.2"
+azure-search-documents = "^11.5.0"
 psycopg2-binary = "^2.9.10"
 
 [tool.poetry.extras]
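The pyproject constraint (`^11.5.0`) is looser than both the lockfile pin (11.5.2) and the version the new ImportError message recommends. A quick standard-library check of what is actually installed:

```python
from importlib.metadata import version

# The PR's ImportError message recommends 11.5.2; the lockfile pins it.
print(version("azure-search-documents"))
```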
tests/vector_stores/test_azure_ai_search.py (new file, 508 lines)
@@ -0,0 +1,508 @@
import json
from unittest.mock import Mock, patch

import pytest
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError

# Import the AzureAISearch class and OutputData model from your module.
from mem0.vector_stores.azure_ai_search import AzureAISearch


# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture
def mock_clients():
    with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
         patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient:
        # Create mocked instances for search and index clients.
        mock_search_client = MockSearchClient.return_value
        mock_index_client = MockIndexClient.return_value

        # Stub required methods on search_client.
        mock_search_client.upload_documents = Mock()
        mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
        mock_search_client.search = Mock()
        mock_search_client.delete_documents = Mock()
        mock_search_client.delete_documents.return_value = [{"status": True, "id": "doc1"}]
        mock_search_client.merge_or_upload_documents = Mock()
        mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": "doc1"}]
        mock_search_client.get_document = Mock()
        mock_search_client.close = Mock()

        # Stub required methods on index_client.
        mock_index_client.create_or_update_index = Mock()
        mock_index_client.list_indexes = Mock(return_value=[])
        mock_index_client.list_index_names = Mock(return_value=["test-index"])
        mock_index_client.delete_index = Mock()
        # For col_info() we assume get_index returns an object with name and fields attributes.
        fake_index = Mock()
        fake_index.name = "test-index"
        fake_index.fields = ["id", "vector", "payload", "user_id", "run_id", "agent_id"]
        mock_index_client.get_index = Mock(return_value=fake_index)
        mock_index_client.close = Mock()

        yield mock_search_client, mock_index_client


@pytest.fixture
def azure_ai_search_instance(mock_clients):
    mock_search_client, mock_index_client = mock_clients
    # Create an instance with dummy parameters.
    instance = AzureAISearch(
        service_name="test-service",
        collection_name="test-index",
        api_key="test-api-key",
        embedding_model_dims=3,
        compression_type="binary",  # testing binary quantization option
        use_float16=True
    )
    # Return instance and clients for verification.
    return instance, mock_search_client, mock_index_client


# --- Original tests ---

def test_create_col(azure_ai_search_instance):
    instance, mock_search_client, mock_index_client = azure_ai_search_instance
    # Upon initialization, create_col should be called.
    mock_index_client.create_or_update_index.assert_called_once()
    # Optionally, you could inspect the call arguments for vector type.


def test_insert(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    vectors = [[0.1, 0.2, 0.3]]
    payloads = [{"user_id": "user1", "run_id": "run1"}]
    ids = ["doc1"]

    instance.insert(vectors, payloads, ids)

    mock_search_client.upload_documents.assert_called_once()
    args, _ = mock_search_client.upload_documents.call_args
    documents = args[0]
    # Update expected_doc to include extra fields from payload.
    expected_doc = {
        "id": "doc1",
        "vector": [0.1, 0.2, 0.3],
        "payload": json.dumps({"user_id": "user1", "run_id": "run1"}),
        "user_id": "user1",
        "run_id": "run1"
    }
    assert documents[0] == expected_doc


def test_search_preFilter(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    # Setup a fake search result returned by the mocked search method.
    fake_result = {
        "id": "doc1",
        "@search.score": 0.95,
        "payload": json.dumps({"user_id": "user1"})
    }
    # Configure the mock to return an iterator (list) with fake_result.
    mock_search_client.search.return_value = [fake_result]

    query_vector = [0.1, 0.2, 0.3]
    results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter")

    # Verify that the search method was called with vector_filter_mode="preFilter".
    mock_search_client.search.assert_called_once()
    _, called_kwargs = mock_search_client.search.call_args
    assert called_kwargs.get("vector_filter_mode") == "preFilter"

    # Verify that the output is parsed correctly.
    assert len(results) == 1
    assert results[0].id == "doc1"
    assert results[0].score == 0.95
    assert results[0].payload == {"user_id": "user1"}


def test_search_postFilter(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    # Setup a fake search result for postFilter.
    fake_result = {
        "id": "doc2",
        "@search.score": 0.85,
        "payload": json.dumps({"user_id": "user2"})
    }
    mock_search_client.search.return_value = [fake_result]

    query_vector = [0.4, 0.5, 0.6]
    results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter")

    mock_search_client.search.assert_called_once()
    _, called_kwargs = mock_search_client.search.call_args
    assert called_kwargs.get("vector_filter_mode") == "postFilter"

    assert len(results) == 1
    assert results[0].id == "doc2"
    assert results[0].score == 0.85
    assert results[0].payload == {"user_id": "user2"}


def test_delete(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    vector_id = "doc1"
    # Set delete_documents to return an iterable with a successful response.
    mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}]
    instance.delete(vector_id)
    mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}])


def test_update(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    vector_id = "doc1"
    new_vector = [0.7, 0.8, 0.9]
    new_payload = {"user_id": "updated"}
    # Set merge_or_upload_documents to return an iterable with a successful response.
    mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}]
    instance.update(vector_id, vector=new_vector, payload=new_payload)
    mock_search_client.merge_or_upload_documents.assert_called_once()
    kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs
    document = kwargs["documents"][0]
    assert document["id"] == vector_id
    assert document["vector"] == new_vector
    assert document["payload"] == json.dumps(new_payload)
    # The update method will also add the 'user_id' field.
    assert document["user_id"] == "updated"


def test_get(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    fake_result = {
        "id": "doc1",
        "payload": json.dumps({"user_id": "user1"})
    }
    mock_search_client.get_document.return_value = fake_result
    result = instance.get("doc1")
    mock_search_client.get_document.assert_called_once_with(key="doc1")
    assert result.id == "doc1"
    assert result.payload == {"user_id": "user1"}
    assert result.score is None


def test_list(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance
    fake_result = {
        "id": "doc1",
        "@search.score": 0.99,
        "payload": json.dumps({"user_id": "user1"})
    }
    mock_search_client.search.return_value = [fake_result]
    # Call list with a simple filter.
    results = instance.list(filters={"user_id": "user1"}, limit=1)
    # Verify the search method was called with the proper parameters.
    expected_filter = instance._build_filter_expression({"user_id": "user1"})
    mock_search_client.search.assert_called_once_with(
        search_text="*",
        filter=expected_filter,
        top=1
    )
    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].id == "doc1"


# --- New tests for practical end-user scenarios ---

def test_bulk_insert(azure_ai_search_instance):
    """Test inserting a batch of documents (common for initial data loading)."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Create a batch of 10 documents
    num_docs = 10
    vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)]
    payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
    ids = [f"doc{i}" for i in range(num_docs)]

    # Configure mock to return success for all documents
    mock_search_client.upload_documents.return_value = [
        {"status": True, "id": id_val} for id_val in ids
    ]

    # Insert the batch
    instance.insert(vectors, payloads, ids)

    # Verify the call
    mock_search_client.upload_documents.assert_called_once()
    args, _ = mock_search_client.upload_documents.call_args
    documents = args[0]
    assert len(documents) == num_docs

    # Verify the first and last document
    assert documents[0]["id"] == "doc0"
    assert documents[-1]["id"] == f"doc{num_docs-1}"


def test_insert_error_handling(azure_ai_search_instance):
    """Test how the class handles Azure errors during insertion."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to return a failure for one document
    mock_search_client.upload_documents.return_value = [
        {"status": False, "id": "doc1", "errorMessage": "Azure error"}
    ]

    vectors = [[0.1, 0.2, 0.3]]
    payloads = [{"user_id": "user1"}]
    ids = ["doc1"]

    # Exception should be raised
    with pytest.raises(Exception) as exc_info:
        instance.insert(vectors, payloads, ids)

    assert "Insert failed" in str(exc_info.value)


def test_search_with_complex_filters(azure_ai_search_instance):
    """Test searching with multiple filter conditions as a user might need."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock response
    mock_search_client.search.return_value = [
        {
            "id": "doc1",
            "@search.score": 0.95,
            "payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"})
        }
    ]

    # Search with multiple filters (common in multi-tenant or segmented applications)
    filters = {
        "user_id": "user1",
        "run_id": "run123",
        "agent_id": "agent456"
    }
    results = instance.search([0.1, 0.2, 0.3], filters=filters)

    # Verify search was called with the correct filter expression
    mock_search_client.search.assert_called_once()
    _, kwargs = mock_search_client.search.call_args
    assert "filter" in kwargs

    # The filter should contain all three conditions
    filter_expr = kwargs["filter"]
    assert "user_id eq 'user1'" in filter_expr
    assert "run_id eq 'run123'" in filter_expr
    assert "agent_id eq 'agent456'" in filter_expr
    assert " and " in filter_expr  # Conditions should be joined by AND


def test_empty_search_results(azure_ai_search_instance):
    """Test behavior when search returns no results (common edge case)."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to return empty results
    mock_search_client.search.return_value = []

    # Search with a non-matching query
    results = instance.search([0.9, 0.9, 0.9], limit=5)

    # Verify result handling
    assert len(results) == 0


def test_get_nonexistent_document(azure_ai_search_instance):
    """Test behavior when getting a document that doesn't exist (should handle gracefully)."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to raise ResourceNotFoundError
    mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found")

    # Get a non-existent document
    result = instance.get("nonexistent_id")

    # Should return None instead of raising exception
    assert result is None


def test_azure_service_error(azure_ai_search_instance):
    """Test handling of Azure service errors (important for robustness)."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to raise HttpResponseError
    http_error = HttpResponseError("Azure service is unavailable")
    mock_search_client.search.side_effect = http_error

    # Attempt to search
    with pytest.raises(HttpResponseError):
        instance.search([0.1, 0.2, 0.3])

    # Verify search was attempted
    mock_search_client.search.assert_called_once()


def test_realistic_workflow(azure_ai_search_instance):
    """Test a realistic workflow: insert → search → update → search again."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # 1. Insert a document
    vector = [0.1, 0.2, 0.3]
    payload = {"user_id": "user1", "content": "Initial content"}
    doc_id = "workflow_doc"

    mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}]
    instance.insert([vector], [payload], [doc_id])

    # 2. Search for the document
    mock_search_client.search.return_value = [
        {
            "id": doc_id,
            "@search.score": 0.95,
            "payload": json.dumps(payload)
        }
    ]
    results = instance.search(vector, filters={"user_id": "user1"})
    assert len(results) == 1
    assert results[0].id == doc_id

    # 3. Update the document
    updated_payload = {"user_id": "user1", "content": "Updated content"}
    mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}]
    instance.update(doc_id, payload=updated_payload)

    # 4. Search again to get updated document
    mock_search_client.search.return_value = [
        {
            "id": doc_id,
            "@search.score": 0.95,
            "payload": json.dumps(updated_payload)
        }
    ]
    results = instance.search(vector, filters={"user_id": "user1"})
    assert len(results) == 1
    assert results[0].id == doc_id
    assert results[0].payload["content"] == "Updated content"


def test_sanitize_special_characters(azure_ai_search_instance):
    """Test that special characters in filter values are properly sanitized."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock response
    mock_search_client.search.return_value = [
        {
            "id": "doc1",
            "@search.score": 0.95,
            "payload": json.dumps({"user_id": "user's-data"})
        }
    ]

    # Search with a filter that has special characters (common in real-world data)
    filters = {"user_id": "user's-data"}
    results = instance.search([0.1, 0.2, 0.3], filters=filters)

    # Verify search was called with properly escaped filter
    mock_search_client.search.assert_called_once()
    _, kwargs = mock_search_client.search.call_args
    assert "filter" in kwargs

    # The filter should have properly escaped single quotes
    filter_expr = kwargs["filter"]
    assert "user_id eq 'user''s-data'" in filter_expr


def test_list_collections(azure_ai_search_instance):
    """Test listing all collections/indexes (for management interfaces)."""
    instance, _, mock_index_client = azure_ai_search_instance

    # List the collections
    collections = instance.list_cols()

    # Verify the correct method was called
    mock_index_client.list_index_names.assert_called_once()

    # Check the result
    assert collections == ["test-index"]


def test_filter_with_numeric_values(azure_ai_search_instance):
    """Test filtering with numeric values (common for faceted search)."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock response
    mock_search_client.search.return_value = [
        {
            "id": "doc1",
            "@search.score": 0.95,
            "payload": json.dumps({"user_id": "user1", "count": 42})
        }
    ]

    # Search with a numeric filter
    # Note: In the actual implementation, numeric fields might need to be in the payload
    filters = {"count": 42}
    results = instance.search([0.1, 0.2, 0.3], filters=filters)

    # Verify the filter expression
    mock_search_client.search.assert_called_once()
    _, kwargs = mock_search_client.search.call_args
    filter_expr = kwargs["filter"]
    assert "count eq 42" in filter_expr  # No quotes for numbers


def test_error_on_update_nonexistent(azure_ai_search_instance):
    """Test behavior when updating a document that doesn't exist."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to return a failure for the update
    mock_search_client.merge_or_upload_documents.return_value = [
        {"status": False, "id": "nonexistent", "errorMessage": "Document not found"}
    ]

    # Attempt to update a non-existent document
    with pytest.raises(Exception) as exc_info:
        instance.update("nonexistent", payload={"new": "data"})

    assert "Update failed" in str(exc_info.value)


def test_different_compression_types():
    """Test creating instances with different compression types (important for performance tuning)."""
    with patch("mem0.vector_stores.azure_ai_search.SearchClient"), \
         patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):

        # Test with scalar compression
        scalar_instance = AzureAISearch(
            service_name="test-service",
            collection_name="scalar-index",
            api_key="test-api-key",
            embedding_model_dims=3,
            compression_type="scalar",
            use_float16=False
        )

        # Test with no compression
        no_compression_instance = AzureAISearch(
            service_name="test-service",
            collection_name="no-compression-index",
            api_key="test-api-key",
            embedding_model_dims=3,
            compression_type=None,
            use_float16=False
        )

        # No assertions needed - we're just verifying that initialization doesn't fail


def test_high_dimensional_vectors():
    """Test handling of high-dimensional vectors typical in AI embeddings."""
    with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
         patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):

        # Configure the mock client
        mock_search_client = MockSearchClient.return_value
        mock_search_client.upload_documents = Mock()
        mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]

        # Create an instance with higher dimensions like those from embedding models
        high_dim_instance = AzureAISearch(
            service_name="test-service",
            collection_name="high-dim-index",
            api_key="test-api-key",
            embedding_model_dims=1536,  # Common for models like OpenAI's embeddings
            compression_type="binary",  # Compression often used with high-dim vectors
            use_float16=True  # Reduced precision often used for memory efficiency
        )

        # Create a high-dimensional vector (stub with zeros for testing)
        high_dim_vector = [0.0] * 1536
        payload = {"user_id": "user1"}
        doc_id = "high_dim_doc"

        # Insert the document
        high_dim_instance.insert([high_dim_vector], [payload], [doc_id])

        # Verify the insert was called with the full vector
        mock_search_client.upload_documents.assert_called_once()
        args, _ = mock_search_client.upload_documents.call_args
        documents = args[0]
        assert len(documents[0]["vector"]) == 1536
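To run just this new file, pytest can also be invoked programmatically, equivalent to `pytest -q tests/vector_stores/test_azure_ai_search.py` (the path comes from the file header above):

```python
import sys

import pytest

# Exit with pytest's status code, mirroring CLI behavior.
sys.exit(pytest.main(["-q", "tests/vector_stores/test_azure_ai_search.py"]))
```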