feat: enhance Azure AI Search Integration with Binary Quantization, Pre/Post Filter Options, and user agent header (#2354)
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
[Azure AI Search](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.
|
||||
# Azure AI Search
|
||||
|
||||
### Usage
|
||||
[Azure AI Search](https://learn.microsoft.com/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
import os
|
||||
from mem0 import Memory
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xx" #this key is used for embedding purpose
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xx" # This key is used for embedding purpose
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
@@ -15,8 +17,8 @@ config = {
|
||||
"service_name": "ai-search-test",
|
||||
"api_key": "*****",
|
||||
"collection_name": "mem0",
|
||||
"embedding_model_dims": 1536 ,
|
||||
"use_compression": False
|
||||
"embedding_model_dims": 1536,
|
||||
"compression_type": "none"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,20 +27,61 @@ m = Memory.from_config(config)
|
||||
messages = [
|
||||
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
|
||||
{"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
|
||||
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
||||
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
||||
]
|
||||
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
||||
```
|
||||
|
||||
### Config
|
||||
## Advanced Usage
|
||||
|
||||
Let's see the available parameters for the `qdrant` config:
|
||||
service_name (str): Azure Cognitive Search service name.
|
||||
| Parameter | Description | Default Value |
|
||||
| --- | --- | --- |
|
||||
| `service_name` | Azure AI Search service name | `None` |
|
||||
| `api_key` | API key of the Azure AI Search service | `None` |
|
||||
| `collection_name` | The name of the collection/index to store the vectors, it will be created automatically if not exist | `mem0` |
|
||||
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
|
||||
| `use_compression` | Use scalar quantization vector compression | False |
|
||||
```python
|
||||
# Search with specific filter mode
|
||||
result = m.search(
|
||||
"sci-fi movies",
|
||||
filters={"user_id": "alice"},
|
||||
limit=5,
|
||||
vector_filter_mode="preFilter" # Apply filters before vector search
|
||||
)
|
||||
|
||||
# Using binary compression for large vector collections
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "azure_ai_search",
|
||||
"config": {
|
||||
"service_name": "ai-search-test",
|
||||
"api_key": "*****",
|
||||
"collection_name": "mem0",
|
||||
"embedding_model_dims": 1536,
|
||||
"compression_type": "binary",
|
||||
"use_float16": True # Use half precision for storage efficiency
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
| Parameter | Description | Default Value | Options |
|
||||
| --- | --- | --- | --- |
|
||||
| `service_name` | Azure AI Search service name | Required | - |
|
||||
| `api_key` | API key of the Azure AI Search service | Required | - |
|
||||
| `collection_name` | The name of the collection/index to store vectors | `mem0` | Any valid index name |
|
||||
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
|
||||
| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
|
||||
| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
|
||||
|
||||
## Notes on Configuration Options
|
||||
|
||||
- **compression_type**:
|
||||
- `none`: No compression, uses full vector precision
|
||||
- `scalar`: Scalar quantization with reasonable balance of speed and accuracy
|
||||
- `binary`: Binary quantization for maximum compression with some accuracy trade-off
|
||||
|
||||
- **vector_filter_mode**:
|
||||
- `preFilter`: Applies filters before vector search (faster)
|
||||
- `postFilter`: Applies filters after vector search (may provide better relevance)
|
||||
|
||||
- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.
|
||||
|
||||
- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -12,6 +13,7 @@ try:
|
||||
from azure.search.documents import SearchClient
|
||||
from azure.search.documents.indexes import SearchIndexClient
|
||||
from azure.search.documents.indexes.models import (
|
||||
BinaryQuantizationCompression,
|
||||
HnswAlgorithmConfiguration,
|
||||
ScalarQuantizationCompression,
|
||||
SearchField,
|
||||
@@ -24,7 +26,7 @@ try:
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'."
|
||||
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,43 +39,82 @@ class OutputData(BaseModel):
|
||||
|
||||
|
||||
class AzureAISearch(VectorStoreBase):
|
||||
def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression):
|
||||
"""Initialize the Azure Cognitive Search vector store.
|
||||
def __init__(
|
||||
self,
|
||||
service_name,
|
||||
collection_name,
|
||||
api_key,
|
||||
embedding_model_dims,
|
||||
compression_type: Optional[str] = None,
|
||||
use_float16: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the Azure AI Search vector store.
|
||||
|
||||
Args:
|
||||
service_name (str): Azure Cognitive Search service name.
|
||||
service_name (str): Azure AI Search service name.
|
||||
collection_name (str): Index name.
|
||||
api_key (str): API key for the Azure Cognitive Search service.
|
||||
api_key (str): API key for the Azure AI Search service.
|
||||
embedding_model_dims (int): Dimension of the embedding vector.
|
||||
use_compression (bool): Use scalar quantization vector compression
|
||||
compression_type (Optional[str]): Specifies the type of quantization to use.
|
||||
Allowed values are None (no quantization), "scalar", or "binary".
|
||||
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
|
||||
(Note: This flag is preserved from the initial implementation per feedback.)
|
||||
"""
|
||||
self.index_name = collection_name
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model_dims = embedding_model_dims
|
||||
self.use_compression = use_compression
|
||||
# If compression_type is None, treat it as "none".
|
||||
self.compression_type = (compression_type or "none").lower()
|
||||
self.use_float16 = use_float16
|
||||
|
||||
self.search_client = SearchClient(
|
||||
endpoint=f"https://{service_name}.search.windows.net",
|
||||
index_name=self.index_name,
|
||||
credential=AzureKeyCredential(api_key),
|
||||
)
|
||||
self.index_client = SearchIndexClient(
|
||||
endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key)
|
||||
endpoint=f"https://{service_name}.search.windows.net",
|
||||
credential=AzureKeyCredential(api_key),
|
||||
)
|
||||
|
||||
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
|
||||
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
|
||||
|
||||
self.create_col() # create the collection / index
|
||||
|
||||
def create_col(self):
|
||||
"""Create a new index in Azure Cognitive Search."""
|
||||
vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector
|
||||
|
||||
if self.use_compression:
|
||||
"""Create a new index in Azure AI Search."""
|
||||
# Determine vector type based on use_float16 setting.
|
||||
if self.use_float16:
|
||||
vector_type = "Collection(Edm.Half)"
|
||||
compression_name = "myCompression"
|
||||
compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
|
||||
else:
|
||||
vector_type = "Collection(Edm.Single)"
|
||||
compression_name = None
|
||||
compression_configurations = []
|
||||
|
||||
# Configure compression settings based on the specified compression_type.
|
||||
compression_configurations = []
|
||||
compression_name = None
|
||||
if self.compression_type == "scalar":
|
||||
compression_name = "myCompression"
|
||||
# For SQ, rescoring defaults to True and oversampling defaults to 4.
|
||||
compression_configurations = [
|
||||
ScalarQuantizationCompression(
|
||||
compression_name=compression_name
|
||||
# rescoring defaults to True and oversampling defaults to 4
|
||||
)
|
||||
]
|
||||
elif self.compression_type == "binary":
|
||||
compression_name = "myCompression"
|
||||
# For BQ, rescoring defaults to True and oversampling defaults to 10.
|
||||
compression_configurations = [
|
||||
BinaryQuantizationCompression(
|
||||
compression_name=compression_name
|
||||
# rescoring defaults to True and oversampling defaults to 10
|
||||
)
|
||||
]
|
||||
# If no compression is desired, compression_configurations remains empty.
|
||||
|
||||
|
||||
fields = [
|
||||
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
|
||||
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
|
||||
@@ -82,8 +123,8 @@ class AzureAISearch(VectorStoreBase):
|
||||
SearchField(
|
||||
name="vector",
|
||||
type=vector_type,
|
||||
searchable=True,
|
||||
vector_search_dimensions=vector_dimensions,
|
||||
searchable=True,
|
||||
vector_search_dimensions=self.embedding_model_dims,
|
||||
vector_search_profile_name="my-vector-config",
|
||||
),
|
||||
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
|
||||
@@ -91,7 +132,11 @@ class AzureAISearch(VectorStoreBase):
|
||||
|
||||
vector_search = VectorSearch(
|
||||
profiles=[
|
||||
VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
|
||||
VectorSearchProfile(
|
||||
name="my-vector-config",
|
||||
algorithm_configuration_name="my-algorithms-config",
|
||||
compression_name=compression_name if self.compression_type != "none" else None
|
||||
)
|
||||
],
|
||||
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
|
||||
compressions=compression_configurations,
|
||||
@@ -101,14 +146,16 @@ class AzureAISearch(VectorStoreBase):
|
||||
|
||||
def _generate_document(self, vector, payload, id):
|
||||
document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
|
||||
# Extract additional fields if they exist
|
||||
# Extract additional fields if they exist.
|
||||
for field in ["user_id", "run_id", "agent_id"]:
|
||||
if field in payload:
|
||||
document[field] = payload[field]
|
||||
return document
|
||||
|
||||
# Note: Explicit "insert" calls may later be decoupled from memory management decisions.
|
||||
def insert(self, vectors, payloads=None, ids=None):
|
||||
"""Insert vectors into the index.
|
||||
"""
|
||||
Insert vectors into the index.
|
||||
|
||||
Args:
|
||||
vectors (List[List[float]]): List of vectors to insert.
|
||||
@@ -116,61 +163,87 @@ class AzureAISearch(VectorStoreBase):
|
||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
|
||||
|
||||
documents = [
|
||||
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
|
||||
self._generate_document(vector, payload, id)
|
||||
for id, vector, payload in zip(ids, vectors, payloads)
|
||||
]
|
||||
self.search_client.upload_documents(documents)
|
||||
response = self.search_client.upload_documents(documents)
|
||||
for doc in response:
|
||||
if not doc.get("status", False):
|
||||
raise Exception(f"Insert failed for document {doc.get('id')}: {doc}")
|
||||
return response
|
||||
|
||||
def _sanitize_key(self, key: str) -> str:
|
||||
return re.sub(r"[^\w]", "", key)
|
||||
|
||||
def _build_filter_expression(self, filters):
|
||||
filter_conditions = []
|
||||
for key, value in filters.items():
|
||||
# If the value is a string, add quotes
|
||||
safe_key = self._sanitize_key(key)
|
||||
if isinstance(value, str):
|
||||
condition = f"{key} eq '{value}'"
|
||||
safe_value = value.replace("'", "''")
|
||||
condition = f"{safe_key} eq '{safe_value}'"
|
||||
else:
|
||||
condition = f"{key} eq {value}"
|
||||
condition = f"{safe_key} eq {value}"
|
||||
filter_conditions.append(condition)
|
||||
# Use 'and' to join multiple conditions
|
||||
filter_expression = " and ".join(filter_conditions)
|
||||
return filter_expression
|
||||
|
||||
def search(self, query, limit=5, filters=None):
|
||||
"""Search for similar vectors.
|
||||
def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"):
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
Args:
|
||||
query (List[float]): Query vectors.
|
||||
query (List[float]): Query vector.
|
||||
limit (int, optional): Number of results to return. Defaults to 5.
|
||||
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||
vector_filter_mode (str): Determines whether filters are applied before or after the vector search.
|
||||
Known values: "preFilter" (default) and "postFilter".
|
||||
|
||||
Returns:
|
||||
list: Search results.
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
# Build filter expression
|
||||
filter_expression = None
|
||||
if filters:
|
||||
filter_expression = self._build_filter_expression(filters)
|
||||
|
||||
vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
|
||||
search_results = self.search_client.search(vector_queries=[vector_query], filter=filter_expression, top=limit)
|
||||
vector_query = VectorizedQuery(
|
||||
vector=query, k_nearest_neighbors=limit, fields="vector"
|
||||
)
|
||||
search_results = self.search_client.search(
|
||||
vector_queries=[vector_query],
|
||||
filter=filter_expression,
|
||||
top=limit,
|
||||
vector_filter_mode=vector_filter_mode,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(result["payload"])
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
results.append(
|
||||
OutputData(
|
||||
id=result["id"], score=result["@search.score"], payload=payload
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def delete(self, vector_id):
|
||||
"""Delete a vector by ID.
|
||||
"""
|
||||
Delete a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to delete.
|
||||
"""
|
||||
self.search_client.delete_documents(documents=[{"id": vector_id}])
|
||||
response = self.search_client.delete_documents(documents=[{"id": vector_id}])
|
||||
for doc in response:
|
||||
if not doc.get("status", False):
|
||||
raise Exception(f"Delete failed for document {vector_id}: {doc}")
|
||||
logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
|
||||
return response
|
||||
|
||||
def update(self, vector_id, vector=None, payload=None):
|
||||
"""Update a vector and its payload.
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to update.
|
||||
@@ -185,10 +258,15 @@ class AzureAISearch(VectorStoreBase):
|
||||
document["payload"] = json_payload
|
||||
for field in ["user_id", "run_id", "agent_id"]:
|
||||
document[field] = payload.get(field)
|
||||
self.search_client.merge_or_upload_documents(documents=[document])
|
||||
response = self.search_client.merge_or_upload_documents(documents=[document])
|
||||
for doc in response:
|
||||
if not doc.get("status", False):
|
||||
raise Exception(f"Update failed for document {vector_id}: {doc}")
|
||||
return response
|
||||
|
||||
def get(self, vector_id) -> OutputData:
|
||||
"""Retrieve a vector by ID.
|
||||
"""
|
||||
Retrieve a vector by ID.
|
||||
|
||||
Args:
|
||||
vector_id (str): ID of the vector to retrieve.
|
||||
@@ -200,35 +278,43 @@ class AzureAISearch(VectorStoreBase):
|
||||
result = self.search_client.get_document(key=vector_id)
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
|
||||
return OutputData(
|
||||
id=result["id"], score=None, payload=json.loads(result["payload"])
|
||||
)
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""List all collections (indexes).
|
||||
"""
|
||||
List all collections (indexes).
|
||||
|
||||
Returns:
|
||||
List[str]: List of index names.
|
||||
"""
|
||||
indexes = self.index_client.list_indexes()
|
||||
return [index.name for index in indexes]
|
||||
try:
|
||||
names = self.index_client.list_index_names()
|
||||
except AttributeError:
|
||||
names = [index.name for index in self.index_client.list_indexes()]
|
||||
return names
|
||||
|
||||
def delete_col(self):
|
||||
"""Delete the index."""
|
||||
self.index_client.delete_index(self.index_name)
|
||||
|
||||
def col_info(self):
|
||||
"""Get information about the index.
|
||||
"""
|
||||
Get information about the index.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Index information.
|
||||
dict: Index information.
|
||||
"""
|
||||
index = self.index_client.get_index(self.index_name)
|
||||
return {"name": index.name, "fields": index.fields}
|
||||
|
||||
def list(self, filters=None, limit=100):
|
||||
"""List all vectors in the index.
|
||||
"""
|
||||
List all vectors in the index.
|
||||
|
||||
Args:
|
||||
filters (Dict, optional): Filters to apply to the list.
|
||||
filters (dict, optional): Filters to apply to the list.
|
||||
limit (int, optional): Number of vectors to return. Defaults to 100.
|
||||
|
||||
Returns:
|
||||
@@ -238,13 +324,18 @@ class AzureAISearch(VectorStoreBase):
|
||||
if filters:
|
||||
filter_expression = self._build_filter_expression(filters)
|
||||
|
||||
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
||||
search_results = self.search_client.search(
|
||||
search_text="*", filter=filter_expression, top=limit
|
||||
)
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(result["payload"])
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
|
||||
return [results]
|
||||
results.append(
|
||||
OutputData(
|
||||
id=result["id"], score=result["@search.score"], payload=payload
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def __del__(self):
|
||||
"""Close the search client when the object is deleted."""
|
||||
|
||||
90
poetry.lock
generated
90
poetry.lock
generated
@@ -112,7 +112,7 @@ propcache = ">=0.2.0"
|
||||
yarl = ">=1.17.0,<2.0"
|
||||
|
||||
[package.extras]
|
||||
speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
|
||||
speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
@@ -158,7 +158,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""}
|
||||
|
||||
[package.extras]
|
||||
doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""]
|
||||
trio = ["trio (>=0.26.1)"]
|
||||
|
||||
[[package]]
|
||||
@@ -184,12 +184,62 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
|
||||
tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
|
||||
tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""]
|
||||
|
||||
[[package]]
|
||||
name = "azure-common"
|
||||
version = "1.1.28"
|
||||
description = "Microsoft Azure Client Library for Python (Common)"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"},
|
||||
{file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-core"
|
||||
version = "1.32.0"
|
||||
description = "Microsoft Azure Core Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "azure_core-1.32.0-py3-none-any.whl", hash = "sha256:eac191a0efb23bfa83fddf321b27b122b4ec847befa3091fa736a5c32c50d7b4"},
|
||||
{file = "azure_core-1.32.0.tar.gz", hash = "sha256:22b3c35d6b2dae14990f6c1be2912bf23ffe50b220e708a28ab1bb92b1c730e5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = ">=2.21.0"
|
||||
six = ">=1.11.0"
|
||||
typing-extensions = ">=4.6.0"
|
||||
|
||||
[package.extras]
|
||||
aio = ["aiohttp (>=3.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "azure-search-documents"
|
||||
version = "11.5.2"
|
||||
description = "Microsoft Azure Cognitive Search Client Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "azure_search_documents-11.5.2-py3-none-any.whl", hash = "sha256:c949d011008a4b0bcee3db91132741b4e4d50ddb3f7e2f48944d949d4b413b11"},
|
||||
{file = "azure_search_documents-11.5.2.tar.gz", hash = "sha256:98977dd1fa4978d3b7d8891a0856b3becb6f02cc07ff2e1ea40b9c7254ada315"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
azure-common = ">=1.1"
|
||||
azure-core = ">=1.28.0"
|
||||
isodate = ">=0.6.0"
|
||||
typing-extensions = ">=4.6.0"
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
@@ -868,7 +918,7 @@ httpcore = "==1.*"
|
||||
idna = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli", "brotlicffi"]
|
||||
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
@@ -910,6 +960,18 @@ files = [
|
||||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isodate"
|
||||
version = "0.7.2"
|
||||
description = "An ISO 8601 date/time/duration parser and formatter"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"},
|
||||
{file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isort"
|
||||
version = "5.13.2"
|
||||
@@ -1799,7 +1861,7 @@ typing-extensions = ">=4.12.2"
|
||||
|
||||
[package.extras]
|
||||
email = ["email-validator (>=2.0.0)"]
|
||||
timezone = ["tzdata"]
|
||||
timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-core"
|
||||
@@ -2213,13 +2275,13 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
|
||||
core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
|
||||
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
@@ -2449,7 +2511,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
|
||||
brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""]
|
||||
h2 = ["h2 (>=4,<5)"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
@@ -25,6 +25,7 @@ sqlalchemy = "^2.0.31"
|
||||
langchain-neo4j = "^0.4.0"
|
||||
neo4j = "^5.23.1"
|
||||
rank-bm25 = "^0.2.2"
|
||||
azure-search-documents = "^11.5.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
508
tests/vector_stores/test_azure_ai_search.py
Normal file
508
tests/vector_stores/test_azure_ai_search.py
Normal file
@@ -0,0 +1,508 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
|
||||
|
||||
# Import the AzureAISearch class and OutputData model from your module.
|
||||
from mem0.vector_stores.azure_ai_search import AzureAISearch
|
||||
|
||||
|
||||
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
|
||||
@pytest.fixture
|
||||
def mock_clients():
|
||||
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient:
|
||||
# Create mocked instances for search and index clients.
|
||||
mock_search_client = MockSearchClient.return_value
|
||||
mock_index_client = MockIndexClient.return_value
|
||||
|
||||
# Stub required methods on search_client.
|
||||
mock_search_client.upload_documents = Mock()
|
||||
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
|
||||
mock_search_client.search = Mock()
|
||||
mock_search_client.delete_documents = Mock()
|
||||
mock_search_client.delete_documents.return_value = [{"status": True, "id": "doc1"}]
|
||||
mock_search_client.merge_or_upload_documents = Mock()
|
||||
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": "doc1"}]
|
||||
mock_search_client.get_document = Mock()
|
||||
mock_search_client.close = Mock()
|
||||
|
||||
# Stub required methods on index_client.
|
||||
mock_index_client.create_or_update_index = Mock()
|
||||
mock_index_client.list_indexes = Mock(return_value=[])
|
||||
mock_index_client.list_index_names = Mock(return_value=["test-index"])
|
||||
mock_index_client.delete_index = Mock()
|
||||
# For col_info() we assume get_index returns an object with name and fields attributes.
|
||||
fake_index = Mock()
|
||||
fake_index.name = "test-index"
|
||||
fake_index.fields = ["id", "vector", "payload", "user_id", "run_id", "agent_id"]
|
||||
mock_index_client.get_index = Mock(return_value=fake_index)
|
||||
mock_index_client.close = Mock()
|
||||
|
||||
yield mock_search_client, mock_index_client
|
||||
|
||||
@pytest.fixture
|
||||
def azure_ai_search_instance(mock_clients):
|
||||
mock_search_client, mock_index_client = mock_clients
|
||||
# Create an instance with dummy parameters.
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="test-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=3,
|
||||
compression_type="binary", # testing binary quantization option
|
||||
use_float16=True
|
||||
)
|
||||
# Return instance and clients for verification.
|
||||
return instance, mock_search_client, mock_index_client
|
||||
|
||||
# --- Original tests ---
|
||||
|
||||
def test_create_col(azure_ai_search_instance):
|
||||
instance, mock_search_client, mock_index_client = azure_ai_search_instance
|
||||
# Upon initialization, create_col should be called.
|
||||
mock_index_client.create_or_update_index.assert_called_once()
|
||||
# Optionally, you could inspect the call arguments for vector type.
|
||||
|
||||
def test_insert(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vectors = [[0.1, 0.2, 0.3]]
|
||||
payloads = [{"user_id": "user1", "run_id": "run1"}]
|
||||
ids = ["doc1"]
|
||||
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
# Update expected_doc to include extra fields from payload.
|
||||
expected_doc = {
|
||||
"id": "doc1",
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"payload": json.dumps({"user_id": "user1", "run_id": "run1"}),
|
||||
"user_id": "user1",
|
||||
"run_id": "run1"
|
||||
}
|
||||
assert documents[0] == expected_doc
|
||||
|
||||
def test_search_preFilter(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
# Setup a fake search result returned by the mocked search method.
|
||||
fake_result = {
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user1"})
|
||||
}
|
||||
# Configure the mock to return an iterator (list) with fake_result.
|
||||
mock_search_client.search.return_value = [fake_result]
|
||||
|
||||
query_vector = [0.1, 0.2, 0.3]
|
||||
results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter")
|
||||
|
||||
# Verify that the search method was called with vector_filter_mode="preFilter".
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, called_kwargs = mock_search_client.search.call_args
|
||||
assert called_kwargs.get("vector_filter_mode") == "preFilter"
|
||||
|
||||
# Verify that the output is parsed correctly.
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].payload == {"user_id": "user1"}
|
||||
|
||||
def test_search_postFilter(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
# Setup a fake search result for postFilter.
|
||||
fake_result = {
|
||||
"id": "doc2",
|
||||
"@search.score": 0.85,
|
||||
"payload": json.dumps({"user_id": "user2"})
|
||||
}
|
||||
mock_search_client.search.return_value = [fake_result]
|
||||
|
||||
query_vector = [0.4, 0.5, 0.6]
|
||||
results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter")
|
||||
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, called_kwargs = mock_search_client.search.call_args
|
||||
assert called_kwargs.get("vector_filter_mode") == "postFilter"
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc2"
|
||||
assert results[0].score == 0.85
|
||||
assert results[0].payload == {"user_id": "user2"}
|
||||
|
||||
def test_delete(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vector_id = "doc1"
|
||||
# Set delete_documents to return an iterable with a successful response.
|
||||
mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}]
|
||||
instance.delete(vector_id)
|
||||
mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}])
|
||||
|
||||
def test_update(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vector_id = "doc1"
|
||||
new_vector = [0.7, 0.8, 0.9]
|
||||
new_payload = {"user_id": "updated"}
|
||||
# Set merge_or_upload_documents to return an iterable with a successful response.
|
||||
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}]
|
||||
instance.update(vector_id, vector=new_vector, payload=new_payload)
|
||||
mock_search_client.merge_or_upload_documents.assert_called_once()
|
||||
kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs
|
||||
document = kwargs["documents"][0]
|
||||
assert document["id"] == vector_id
|
||||
assert document["vector"] == new_vector
|
||||
assert document["payload"] == json.dumps(new_payload)
|
||||
# The update method will also add the 'user_id' field.
|
||||
assert document["user_id"] == "updated"
|
||||
|
||||
def test_get(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
fake_result = {
|
||||
"id": "doc1",
|
||||
"payload": json.dumps({"user_id": "user1"})
|
||||
}
|
||||
mock_search_client.get_document.return_value = fake_result
|
||||
result = instance.get("doc1")
|
||||
mock_search_client.get_document.assert_called_once_with(key="doc1")
|
||||
assert result.id == "doc1"
|
||||
assert result.payload == {"user_id": "user1"}
|
||||
assert result.score is None
|
||||
|
||||
def test_list(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
fake_result = {
|
||||
"id": "doc1",
|
||||
"@search.score": 0.99,
|
||||
"payload": json.dumps({"user_id": "user1"})
|
||||
}
|
||||
mock_search_client.search.return_value = [fake_result]
|
||||
# Call list with a simple filter.
|
||||
results = instance.list(filters={"user_id": "user1"}, limit=1)
|
||||
# Verify the search method was called with the proper parameters.
|
||||
expected_filter = instance._build_filter_expression({"user_id": "user1"})
|
||||
mock_search_client.search.assert_called_once_with(
|
||||
search_text="*",
|
||||
filter=expected_filter,
|
||||
top=1
|
||||
)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc1"
|
||||
|
||||
# --- New tests for practical end-user scenarios ---
|
||||
|
||||
def test_bulk_insert(azure_ai_search_instance):
|
||||
"""Test inserting a batch of documents (common for initial data loading)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Create a batch of 10 documents
|
||||
num_docs = 10
|
||||
vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)]
|
||||
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
|
||||
ids = [f"doc{i}" for i in range(num_docs)]
|
||||
|
||||
# Configure mock to return success for all documents
|
||||
mock_search_client.upload_documents.return_value = [
|
||||
{"status": True, "id": id_val} for id_val in ids
|
||||
]
|
||||
|
||||
# Insert the batch
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
# Verify the call
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
assert len(documents) == num_docs
|
||||
|
||||
# Verify the first and last document
|
||||
assert documents[0]["id"] == "doc0"
|
||||
assert documents[-1]["id"] == f"doc{num_docs-1}"
|
||||
|
||||
|
||||
def test_insert_error_handling(azure_ai_search_instance):
|
||||
"""Test how the class handles Azure errors during insertion."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return a failure for one document
|
||||
mock_search_client.upload_documents.return_value = [
|
||||
{"status": False, "id": "doc1", "errorMessage": "Azure error"}
|
||||
]
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3]]
|
||||
payloads = [{"user_id": "user1"}]
|
||||
ids = ["doc1"]
|
||||
|
||||
# Exception should be raised
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
assert "Insert failed" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_search_with_complex_filters(azure_ai_search_instance):
|
||||
"""Test searching with multiple filter conditions as a user might need."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock response
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"})
|
||||
}
|
||||
]
|
||||
|
||||
# Search with multiple filters (common in multi-tenant or segmented applications)
|
||||
filters = {
|
||||
"user_id": "user1",
|
||||
"run_id": "run123",
|
||||
"agent_id": "agent456"
|
||||
}
|
||||
results = instance.search([0.1, 0.2, 0.3], filters=filters)
|
||||
|
||||
# Verify search was called with the correct filter expression
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
assert "filter" in kwargs
|
||||
|
||||
# The filter should contain all three conditions
|
||||
filter_expr = kwargs["filter"]
|
||||
assert "user_id eq 'user1'" in filter_expr
|
||||
assert "run_id eq 'run123'" in filter_expr
|
||||
assert "agent_id eq 'agent456'" in filter_expr
|
||||
assert " and " in filter_expr # Conditions should be joined by AND
|
||||
|
||||
|
||||
def test_empty_search_results(azure_ai_search_instance):
|
||||
"""Test behavior when search returns no results (common edge case)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return empty results
|
||||
mock_search_client.search.return_value = []
|
||||
|
||||
# Search with a non-matching query
|
||||
results = instance.search([0.9, 0.9, 0.9], limit=5)
|
||||
|
||||
# Verify result handling
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_get_nonexistent_document(azure_ai_search_instance):
|
||||
"""Test behavior when getting a document that doesn't exist (should handle gracefully)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to raise ResourceNotFoundError
|
||||
mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found")
|
||||
|
||||
# Get a non-existent document
|
||||
result = instance.get("nonexistent_id")
|
||||
|
||||
# Should return None instead of raising exception
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_azure_service_error(azure_ai_search_instance):
|
||||
"""Test handling of Azure service errors (important for robustness)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to raise HttpResponseError
|
||||
http_error = HttpResponseError("Azure service is unavailable")
|
||||
mock_search_client.search.side_effect = http_error
|
||||
|
||||
# Attempt to search
|
||||
with pytest.raises(HttpResponseError):
|
||||
instance.search([0.1, 0.2, 0.3])
|
||||
|
||||
# Verify search was attempted
|
||||
mock_search_client.search.assert_called_once()
|
||||
|
||||
|
||||
def test_realistic_workflow(azure_ai_search_instance):
|
||||
"""Test a realistic workflow: insert → search → update → search again."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# 1. Insert a document
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
payload = {"user_id": "user1", "content": "Initial content"}
|
||||
doc_id = "workflow_doc"
|
||||
|
||||
mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}]
|
||||
instance.insert([vector], [payload], [doc_id])
|
||||
|
||||
# 2. Search for the document
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": doc_id,
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps(payload)
|
||||
}
|
||||
]
|
||||
results = instance.search(vector, filters={"user_id": "user1"})
|
||||
assert len(results) == 1
|
||||
assert results[0].id == doc_id
|
||||
|
||||
# 3. Update the document
|
||||
updated_payload = {"user_id": "user1", "content": "Updated content"}
|
||||
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}]
|
||||
instance.update(doc_id, payload=updated_payload)
|
||||
|
||||
# 4. Search again to get updated document
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": doc_id,
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps(updated_payload)
|
||||
}
|
||||
]
|
||||
results = instance.search(vector, filters={"user_id": "user1"})
|
||||
assert len(results) == 1
|
||||
assert results[0].id == doc_id
|
||||
assert results[0].payload["content"] == "Updated content"
|
||||
|
||||
|
||||
def test_sanitize_special_characters(azure_ai_search_instance):
|
||||
"""Test that special characters in filter values are properly sanitized."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock response
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user's-data"})
|
||||
}
|
||||
]
|
||||
|
||||
# Search with a filter that has special characters (common in real-world data)
|
||||
filters = {"user_id": "user's-data"}
|
||||
results = instance.search([0.1, 0.2, 0.3], filters=filters)
|
||||
|
||||
# Verify search was called with properly escaped filter
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
assert "filter" in kwargs
|
||||
|
||||
# The filter should have properly escaped single quotes
|
||||
filter_expr = kwargs["filter"]
|
||||
assert "user_id eq 'user''s-data'" in filter_expr
|
||||
|
||||
|
||||
def test_list_collections(azure_ai_search_instance):
|
||||
"""Test listing all collections/indexes (for management interfaces)."""
|
||||
instance, _, mock_index_client = azure_ai_search_instance
|
||||
|
||||
# List the collections
|
||||
collections = instance.list_cols()
|
||||
|
||||
# Verify the correct method was called
|
||||
mock_index_client.list_index_names.assert_called_once()
|
||||
|
||||
# Check the result
|
||||
assert collections == ["test-index"]
|
||||
|
||||
|
||||
def test_filter_with_numeric_values(azure_ai_search_instance):
|
||||
"""Test filtering with numeric values (common for faceted search)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock response
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user1", "count": 42})
|
||||
}
|
||||
]
|
||||
|
||||
# Search with a numeric filter
|
||||
# Note: In the actual implementation, numeric fields might need to be in the payload
|
||||
filters = {"count": 42}
|
||||
results = instance.search([0.1, 0.2, 0.3], filters=filters)
|
||||
|
||||
# Verify the filter expression
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
filter_expr = kwargs["filter"]
|
||||
assert "count eq 42" in filter_expr # No quotes for numbers
|
||||
|
||||
|
||||
def test_error_on_update_nonexistent(azure_ai_search_instance):
|
||||
"""Test behavior when updating a document that doesn't exist."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return a failure for the update
|
||||
mock_search_client.merge_or_upload_documents.return_value = [
|
||||
{"status": False, "id": "nonexistent", "errorMessage": "Document not found"}
|
||||
]
|
||||
|
||||
# Attempt to update a non-existent document
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
instance.update("nonexistent", payload={"new": "data"})
|
||||
|
||||
assert "Update failed" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_different_compression_types():
|
||||
"""Test creating instances with different compression types (important for performance tuning)."""
|
||||
with patch("mem0.vector_stores.azure_ai_search.SearchClient"), \
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
|
||||
|
||||
# Test with scalar compression
|
||||
scalar_instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="scalar-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=3,
|
||||
compression_type="scalar",
|
||||
use_float16=False
|
||||
)
|
||||
|
||||
# Test with no compression
|
||||
no_compression_instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="no-compression-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=3,
|
||||
compression_type=None,
|
||||
use_float16=False
|
||||
)
|
||||
|
||||
# No assertions needed - we're just verifying that initialization doesn't fail
|
||||
|
||||
|
||||
def test_high_dimensional_vectors():
|
||||
"""Test handling of high-dimensional vectors typical in AI embeddings."""
|
||||
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
|
||||
|
||||
# Configure the mock client
|
||||
mock_search_client = MockSearchClient.return_value
|
||||
mock_search_client.upload_documents = Mock()
|
||||
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
|
||||
|
||||
# Create an instance with higher dimensions like those from embedding models
|
||||
high_dim_instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="high-dim-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=1536, # Common for models like OpenAI's embeddings
|
||||
compression_type="binary", # Compression often used with high-dim vectors
|
||||
use_float16=True # Reduced precision often used for memory efficiency
|
||||
)
|
||||
|
||||
# Create a high-dimensional vector (stub with zeros for testing)
|
||||
high_dim_vector = [0.0] * 1536
|
||||
payload = {"user_id": "user1"}
|
||||
doc_id = "high_dim_doc"
|
||||
|
||||
# Insert the document
|
||||
high_dim_instance.insert([high_dim_vector], [payload], [doc_id])
|
||||
|
||||
# Verify the insert was called with the full vector
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
assert len(documents[0]["vector"]) == 1536
|
||||
Reference in New Issue
Block a user