feat: enhance Azure AI Search integration with binary quantization, pre/post filter options, and User-Agent header (#2354)

Farzad Sunavala authored 2025-03-12 10:50:25 -05:00, committed by GitHub
parent 65f826e064
commit ba9c61938b
5 changed files with 788 additions and 83 deletions


@@ -1,12 +1,14 @@
# Azure AI Search
[Azure AI Search](https://learn.microsoft.com/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.
## Usage
```python
import os
from mem0 import Memory
os.environ["OPENAI_API_KEY"] = "sk-xx" #this key is used for embedding purpose
os.environ["OPENAI_API_KEY"] = "sk-xx" # This key is used for embedding purpose
config = {
"vector_store": {
@@ -15,8 +17,8 @@ config = {
"service_name": "ai-search-test",
"api_key": "*****",
"collection_name": "mem0",
"embedding_model_dims": 1536 ,
"use_compression": False
"embedding_model_dims": 1536,
"compression_type": "none"
}
}
}
@@ -25,20 +27,61 @@ m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
{"role": "user", "content": "Im not a big fan of thriller movies but I love sci-fi movies."},
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
## Advanced Usage
```python
# Search with specific filter mode
result = m.search(
"sci-fi movies",
filters={"user_id": "alice"},
limit=5,
vector_filter_mode="preFilter" # Apply filters before vector search
)
# Using binary compression for large vector collections
config = {
"vector_store": {
"provider": "azure_ai_search",
"config": {
"service_name": "ai-search-test",
"api_key": "*****",
"collection_name": "mem0",
"embedding_model_dims": 1536,
"compression_type": "binary",
"use_float16": True # Use half precision for storage efficiency
}
}
}
```
## Configuration Parameters
| Parameter | Description | Default Value | Options |
| --- | --- | --- | --- |
| `service_name` | Azure AI Search service name | Required | - |
| `api_key` | API key of the Azure AI Search service | Required | - |
| `collection_name` | The name of the collection/index to store vectors | `mem0` | Any valid index name |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
## Notes on Configuration Options
- **compression_type**:
- `none`: No compression, uses full vector precision
- `scalar`: Scalar quantization with a reasonable balance of speed and accuracy
- `binary`: Binary quantization for maximum compression with some accuracy trade-off
- **vector_filter_mode**:
- `preFilter`: Applies filters before vector search (faster)
- `postFilter`: Applies filters after vector search (may provide better relevance; see the example below)
- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.
- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
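For illustration, a post-filter search on the same `Memory` instance from the usage example might look like this (a minimal sketch mirroring the pre-filter snippet above):

```python
# Apply filters *after* the vector search: candidates are scored first and
# then narrowed, which can improve relevance when the filter is selective.
result = m.search(
    "sci-fi movies",
    filters={"user_id": "alice"},
    limit=5,
    vector_filter_mode="postFilter",
)
```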


@@ -1,5 +1,6 @@
import json
import logging
import re
from typing import List, Optional
from pydantic import BaseModel
@@ -12,6 +13,7 @@ try:
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
BinaryQuantizationCompression,
HnswAlgorithmConfiguration,
ScalarQuantizationCompression,
SearchField,
@@ -24,7 +26,7 @@ try:
from azure.search.documents.models import VectorizedQuery
except ImportError:
raise ImportError(
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'."
"The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
)
logger = logging.getLogger(__name__)
@@ -37,43 +39,82 @@ class OutputData(BaseModel):
class AzureAISearch(VectorStoreBase):
def __init__(
self,
service_name,
collection_name,
api_key,
embedding_model_dims,
compression_type: Optional[str] = None,
use_float16: bool = False,
):
"""
Initialize the Azure AI Search vector store.
Args:
service_name (str): Azure AI Search service name.
collection_name (str): Index name.
api_key (str): API key for the Azure AI Search service.
embedding_model_dims (int): Dimension of the embedding vector.
compression_type (Optional[str]): Specifies the type of quantization to use.
Allowed values are None (no quantization), "scalar", or "binary".
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
(Note: This flag is preserved from the initial implementation per feedback.)
"""
self.index_name = collection_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
# If compression_type is None, treat it as "none".
self.compression_type = (compression_type or "none").lower()
self.use_float16 = use_float16
self.search_client = SearchClient(
endpoint=f"https://{service_name}.search.windows.net",
index_name=self.index_name,
credential=AzureKeyCredential(api_key),
)
self.index_client = SearchIndexClient(
endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key)
endpoint=f"https://{service_name}.search.windows.net",
credential=AzureKeyCredential(api_key),
)
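# Append "mem0" to the User-Agent header so this client's requests are identifiable server-side.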
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
self.create_col() # create the collection / index
def create_col(self):
"""Create a new index in Azure Cognitive Search."""
vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector
if self.use_compression:
"""Create a new index in Azure AI Search."""
# Determine vector type based on use_float16 setting.
if self.use_float16:
vector_type = "Collection(Edm.Half)"
compression_name = "myCompression"
compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
else:
vector_type = "Collection(Edm.Single)"
# Configure compression settings based on the specified compression_type.
compression_configurations = []
compression_name = None
if self.compression_type == "scalar":
compression_name = "myCompression"
# For SQ, rescoring defaults to True and oversampling defaults to 4.
compression_configurations = [
ScalarQuantizationCompression(
compression_name=compression_name
# rescoring defaults to True and oversampling defaults to 4
)
]
elif self.compression_type == "binary":
compression_name = "myCompression"
# For BQ, rescoring defaults to True and oversampling defaults to 10.
compression_configurations = [
BinaryQuantizationCompression(
compression_name=compression_name
# rescoring defaults to True and oversampling defaults to 10
)
]
# If no compression is desired, compression_configurations remains empty.
fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
@@ -82,8 +123,8 @@ class AzureAISearch(VectorStoreBase):
SearchField(
name="vector",
type=vector_type,
searchable=True,
vector_search_dimensions=self.embedding_model_dims,
vector_search_profile_name="my-vector-config",
),
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
@@ -91,7 +132,11 @@ class AzureAISearch(VectorStoreBase):
vector_search = VectorSearch(
profiles=[
VectorSearchProfile(
name="my-vector-config",
algorithm_configuration_name="my-algorithms-config",
compression_name=compression_name if self.compression_type != "none" else None
)
],
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
compressions=compression_configurations,
@@ -101,14 +146,16 @@ class AzureAISearch(VectorStoreBase):
def _generate_document(self, vector, payload, id):
document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
# Extract additional fields if they exist.
for field in ["user_id", "run_id", "agent_id"]:
if field in payload:
document[field] = payload[field]
return document
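# Illustration (hypothetical values): _generate_document([0.1, 0.2], {"user_id": "alice"}, "doc1")
# returns {"id": "doc1", "vector": [0.1, 0.2], "payload": '{"user_id": "alice"}', "user_id": "alice"}.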
# Note: Explicit "insert" calls may later be decoupled from memory management decisions.
def insert(self, vectors, payloads=None, ids=None):
"""Insert vectors into the index.
"""
Insert vectors into the index.
Args:
vectors (List[List[float]]): List of vectors to insert.
@@ -116,61 +163,87 @@ class AzureAISearch(VectorStoreBase):
ids (List[str], optional): List of IDs corresponding to vectors.
"""
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
documents = [
self._generate_document(vector, payload, id)
for id, vector, payload in zip(ids, vectors, payloads)
]
response = self.search_client.upload_documents(documents)
for doc in response:
if not doc.get("status", False):
raise Exception(f"Insert failed for document {doc.get('id')}: {doc}")
return response
def _sanitize_key(self, key: str) -> str:
return re.sub(r"[^\w]", "", key)
def _build_filter_expression(self, filters):
filter_conditions = []
for key, value in filters.items():
safe_key = self._sanitize_key(key)
if isinstance(value, str):
condition = f"{key} eq '{value}'"
safe_value = value.replace("'", "''")
condition = f"{safe_key} eq '{safe_value}'"
else:
condition = f"{key} eq {value}"
condition = f"{safe_key} eq {value}"
filter_conditions.append(condition)
# Use 'and' to join multiple conditions
filter_expression = " and ".join(filter_conditions)
return filter_expression
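# Illustration (hypothetical values): _build_filter_expression({"user_id": "alice", "run_id": "run1"})
# yields the OData expression "user_id eq 'alice' and run_id eq 'run1'".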
def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"):
"""
Search for similar vectors.
Args:
query (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
vector_filter_mode (str): Determines whether filters are applied before or after the vector search.
Known values: "preFilter" (default) and "postFilter".
Returns:
List[OutputData]: Search results.
"""
# Build filter expression
filter_expression = None
if filters:
filter_expression = self._build_filter_expression(filters)
vector_query = VectorizedQuery(
vector=query, k_nearest_neighbors=limit, fields="vector"
)
search_results = self.search_client.search(
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=vector_filter_mode,
)
results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
return results
def delete(self, vector_id):
"""Delete a vector by ID.
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete.
"""
response = self.search_client.delete_documents(documents=[{"id": vector_id}])
for doc in response:
if not doc.get("status", False):
raise Exception(f"Delete failed for document {vector_id}: {doc}")
logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
return response
def update(self, vector_id, vector=None, payload=None):
"""Update a vector and its payload.
"""
Update a vector and its payload.
Args:
vector_id (str): ID of the vector to update.
@@ -185,10 +258,15 @@ class AzureAISearch(VectorStoreBase):
document["payload"] = json_payload
for field in ["user_id", "run_id", "agent_id"]:
document[field] = payload.get(field)
response = self.search_client.merge_or_upload_documents(documents=[document])
for doc in response:
if not doc.get("status", False):
raise Exception(f"Update failed for document {vector_id}: {doc}")
return response
def get(self, vector_id) -> OutputData:
"""Retrieve a vector by ID.
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve.
@@ -200,35 +278,43 @@ class AzureAISearch(VectorStoreBase):
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
return OutputData(
id=result["id"], score=None, payload=json.loads(result["payload"])
)
def list_cols(self) -> List[str]:
"""List all collections (indexes).
"""
List all collections (indexes).
Returns:
List[str]: List of index names.
"""
try:
names = self.index_client.list_index_names()
except AttributeError:
names = [index.name for index in self.index_client.list_indexes()]
return names
def delete_col(self):
"""Delete the index."""
self.index_client.delete_index(self.index_name)
def col_info(self):
"""Get information about the index.
"""
Get information about the index.
Returns:
dict: Index information.
"""
index = self.index_client.get_index(self.index_name)
return {"name": index.name, "fields": index.fields}
def list(self, filters=None, limit=100):
"""List all vectors in the index.
"""
List all vectors in the index.
Args:
filters (dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:
@@ -238,13 +324,18 @@ class AzureAISearch(VectorStoreBase):
if filters:
filter_expression = self._build_filter_expression(filters)
search_results = self.search_client.search(
search_text="*", filter=filter_expression, top=limit
)
results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
return results
def __del__(self):
"""Close the search client when the object is deleted."""

poetry.lock (generated)

@@ -112,7 +112,7 @@ propcache = ">=0.2.0"
yarl = ">=1.17.0,<2.0"
[package.extras]
speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"]
speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
[[package]]
name = "aiosignal"
@@ -158,7 +158,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""}
[package.extras]
doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""]
trio = ["trio (>=0.26.1)"]
[[package]]
@@ -184,12 +184,62 @@ files = [
]
[package.extras]
benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""]
[[package]]
name = "azure-common"
version = "1.1.28"
description = "Microsoft Azure Client Library for Python (Common)"
optional = false
python-versions = "*"
groups = ["main"]
files = [
{file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"},
{file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"},
]
[[package]]
name = "azure-core"
version = "1.32.0"
description = "Microsoft Azure Core Library for Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "azure_core-1.32.0-py3-none-any.whl", hash = "sha256:eac191a0efb23bfa83fddf321b27b122b4ec847befa3091fa736a5c32c50d7b4"},
{file = "azure_core-1.32.0.tar.gz", hash = "sha256:22b3c35d6b2dae14990f6c1be2912bf23ffe50b220e708a28ab1bb92b1c730e5"},
]
[package.dependencies]
requests = ">=2.21.0"
six = ">=1.11.0"
typing-extensions = ">=4.6.0"
[package.extras]
aio = ["aiohttp (>=3.0)"]
[[package]]
name = "azure-search-documents"
version = "11.5.2"
description = "Microsoft Azure Cognitive Search Client Library for Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "azure_search_documents-11.5.2-py3-none-any.whl", hash = "sha256:c949d011008a4b0bcee3db91132741b4e4d50ddb3f7e2f48944d949d4b413b11"},
{file = "azure_search_documents-11.5.2.tar.gz", hash = "sha256:98977dd1fa4978d3b7d8891a0856b3becb6f02cc07ff2e1ea40b9c7254ada315"},
]
[package.dependencies]
azure-common = ">=1.1"
azure-core = ">=1.28.0"
isodate = ">=0.6.0"
typing-extensions = ">=4.6.0"
[[package]]
name = "backoff"
@@ -868,7 +918,7 @@ httpcore = "==1.*"
idna = "*"
[package.extras]
brotli = ["brotli", "brotlicffi"]
brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
@@ -910,6 +960,18 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "isodate"
version = "0.7.2"
description = "An ISO 8601 date/time/duration parser and formatter"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"},
{file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"},
]
[[package]]
name = "isort"
version = "5.13.2"
@@ -1799,7 +1861,7 @@ typing-extensions = ">=4.12.2"
[package.extras]
email = ["email-validator (>=2.0.0)"]
timezone = ["tzdata"]
timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
[[package]]
name = "pydantic-core"
@@ -2213,13 +2275,13 @@ files = [
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"]
core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
[[package]]
name = "six"
@@ -2449,7 +2511,7 @@ files = [
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""]
h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]


@@ -25,6 +25,7 @@ sqlalchemy = "^2.0.31"
langchain-neo4j = "^0.4.0"
neo4j = "^5.23.1"
rank-bm25 = "^0.2.2"
azure-search-documents = "^11.5.0"
psycopg2-binary = "^2.9.10"
[tool.poetry.extras]


@@ -0,0 +1,508 @@
import json
from unittest.mock import Mock, patch
import pytest
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
# Import the AzureAISearch class and OutputData model from your module.
from mem0.vector_stores.azure_ai_search import AzureAISearch
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture
def mock_clients():
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient:
# Create mocked instances for search and index clients.
mock_search_client = MockSearchClient.return_value
mock_index_client = MockIndexClient.return_value
# Stub required methods on search_client.
mock_search_client.upload_documents = Mock()
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
mock_search_client.search = Mock()
mock_search_client.delete_documents = Mock()
mock_search_client.delete_documents.return_value = [{"status": True, "id": "doc1"}]
mock_search_client.merge_or_upload_documents = Mock()
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": "doc1"}]
mock_search_client.get_document = Mock()
mock_search_client.close = Mock()
# Stub required methods on index_client.
mock_index_client.create_or_update_index = Mock()
mock_index_client.list_indexes = Mock(return_value=[])
mock_index_client.list_index_names = Mock(return_value=["test-index"])
mock_index_client.delete_index = Mock()
# For col_info() we assume get_index returns an object with name and fields attributes.
fake_index = Mock()
fake_index.name = "test-index"
fake_index.fields = ["id", "vector", "payload", "user_id", "run_id", "agent_id"]
mock_index_client.get_index = Mock(return_value=fake_index)
mock_index_client.close = Mock()
yield mock_search_client, mock_index_client
@pytest.fixture
def azure_ai_search_instance(mock_clients):
mock_search_client, mock_index_client = mock_clients
# Create an instance with dummy parameters.
instance = AzureAISearch(
service_name="test-service",
collection_name="test-index",
api_key="test-api-key",
embedding_model_dims=3,
compression_type="binary", # testing binary quantization option
use_float16=True
)
# Return instance and clients for verification.
return instance, mock_search_client, mock_index_client
# --- Original tests ---
def test_create_col(azure_ai_search_instance):
instance, mock_search_client, mock_index_client = azure_ai_search_instance
# Upon initialization, create_col should be called.
mock_index_client.create_or_update_index.assert_called_once()
# Optionally, you could inspect the call arguments for vector type.
def test_insert(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1", "run_id": "run1"}]
ids = ["doc1"]
instance.insert(vectors, payloads, ids)
mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args
documents = args[0]
# Update expected_doc to include extra fields from payload.
expected_doc = {
"id": "doc1",
"vector": [0.1, 0.2, 0.3],
"payload": json.dumps({"user_id": "user1", "run_id": "run1"}),
"user_id": "user1",
"run_id": "run1"
}
assert documents[0] == expected_doc
def test_search_preFilter(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
# Setup a fake search result returned by the mocked search method.
fake_result = {
"id": "doc1",
"@search.score": 0.95,
"payload": json.dumps({"user_id": "user1"})
}
# Configure the mock to return an iterator (list) with fake_result.
mock_search_client.search.return_value = [fake_result]
query_vector = [0.1, 0.2, 0.3]
results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter")
# Verify that the search method was called with vector_filter_mode="preFilter".
mock_search_client.search.assert_called_once()
_, called_kwargs = mock_search_client.search.call_args
assert called_kwargs.get("vector_filter_mode") == "preFilter"
# Verify that the output is parsed correctly.
assert len(results) == 1
assert results[0].id == "doc1"
assert results[0].score == 0.95
assert results[0].payload == {"user_id": "user1"}
def test_search_postFilter(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
# Setup a fake search result for postFilter.
fake_result = {
"id": "doc2",
"@search.score": 0.85,
"payload": json.dumps({"user_id": "user2"})
}
mock_search_client.search.return_value = [fake_result]
query_vector = [0.4, 0.5, 0.6]
results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter")
mock_search_client.search.assert_called_once()
_, called_kwargs = mock_search_client.search.call_args
assert called_kwargs.get("vector_filter_mode") == "postFilter"
assert len(results) == 1
assert results[0].id == "doc2"
assert results[0].score == 0.85
assert results[0].payload == {"user_id": "user2"}
def test_delete(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
vector_id = "doc1"
# Set delete_documents to return an iterable with a successful response.
mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}]
instance.delete(vector_id)
mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}])
def test_update(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
vector_id = "doc1"
new_vector = [0.7, 0.8, 0.9]
new_payload = {"user_id": "updated"}
# Set merge_or_upload_documents to return an iterable with a successful response.
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}]
instance.update(vector_id, vector=new_vector, payload=new_payload)
mock_search_client.merge_or_upload_documents.assert_called_once()
kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs
document = kwargs["documents"][0]
assert document["id"] == vector_id
assert document["vector"] == new_vector
assert document["payload"] == json.dumps(new_payload)
# The update method will also add the 'user_id' field.
assert document["user_id"] == "updated"
def test_get(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
fake_result = {
"id": "doc1",
"payload": json.dumps({"user_id": "user1"})
}
mock_search_client.get_document.return_value = fake_result
result = instance.get("doc1")
mock_search_client.get_document.assert_called_once_with(key="doc1")
assert result.id == "doc1"
assert result.payload == {"user_id": "user1"}
assert result.score is None
def test_list(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
fake_result = {
"id": "doc1",
"@search.score": 0.99,
"payload": json.dumps({"user_id": "user1"})
}
mock_search_client.search.return_value = [fake_result]
# Call list with a simple filter.
results = instance.list(filters={"user_id": "user1"}, limit=1)
# Verify the search method was called with the proper parameters.
expected_filter = instance._build_filter_expression({"user_id": "user1"})
mock_search_client.search.assert_called_once_with(
search_text="*",
filter=expected_filter,
top=1
)
assert isinstance(results, list)
assert len(results) == 1
assert results[0].id == "doc1"
# --- New tests for practical end-user scenarios ---
def test_bulk_insert(azure_ai_search_instance):
"""Test inserting a batch of documents (common for initial data loading)."""
instance, mock_search_client, _ = azure_ai_search_instance
# Create a batch of 10 documents
num_docs = 10
vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)]
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
ids = [f"doc{i}" for i in range(num_docs)]
# Configure mock to return success for all documents
mock_search_client.upload_documents.return_value = [
{"status": True, "id": id_val} for id_val in ids
]
# Insert the batch
instance.insert(vectors, payloads, ids)
# Verify the call
mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args
documents = args[0]
assert len(documents) == num_docs
# Verify the first and last document
assert documents[0]["id"] == "doc0"
assert documents[-1]["id"] == f"doc{num_docs-1}"
def test_insert_error_handling(azure_ai_search_instance):
"""Test how the class handles Azure errors during insertion."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to return a failure for one document
mock_search_client.upload_documents.return_value = [
{"status": False, "id": "doc1", "errorMessage": "Azure error"}
]
vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
ids = ["doc1"]
# Exception should be raised
with pytest.raises(Exception) as exc_info:
instance.insert(vectors, payloads, ids)
assert "Insert failed" in str(exc_info.value)
def test_search_with_complex_filters(azure_ai_search_instance):
"""Test searching with multiple filter conditions as a user might need."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock response
mock_search_client.search.return_value = [
{
"id": "doc1",
"@search.score": 0.95,
"payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"})
}
]
# Search with multiple filters (common in multi-tenant or segmented applications)
filters = {
"user_id": "user1",
"run_id": "run123",
"agent_id": "agent456"
}
results = instance.search([0.1, 0.2, 0.3], filters=filters)
# Verify search was called with the correct filter expression
mock_search_client.search.assert_called_once()
_, kwargs = mock_search_client.search.call_args
assert "filter" in kwargs
# The filter should contain all three conditions
filter_expr = kwargs["filter"]
assert "user_id eq 'user1'" in filter_expr
assert "run_id eq 'run123'" in filter_expr
assert "agent_id eq 'agent456'" in filter_expr
assert " and " in filter_expr # Conditions should be joined by AND
def test_empty_search_results(azure_ai_search_instance):
"""Test behavior when search returns no results (common edge case)."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to return empty results
mock_search_client.search.return_value = []
# Search with a non-matching query
results = instance.search([0.9, 0.9, 0.9], limit=5)
# Verify result handling
assert len(results) == 0
def test_get_nonexistent_document(azure_ai_search_instance):
"""Test behavior when getting a document that doesn't exist (should handle gracefully)."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to raise ResourceNotFoundError
mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found")
# Get a non-existent document
result = instance.get("nonexistent_id")
# Should return None instead of raising exception
assert result is None
def test_azure_service_error(azure_ai_search_instance):
"""Test handling of Azure service errors (important for robustness)."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to raise HttpResponseError
http_error = HttpResponseError("Azure service is unavailable")
mock_search_client.search.side_effect = http_error
# Attempt to search
with pytest.raises(HttpResponseError):
instance.search([0.1, 0.2, 0.3])
# Verify search was attempted
mock_search_client.search.assert_called_once()
def test_realistic_workflow(azure_ai_search_instance):
"""Test a realistic workflow: insert → search → update → search again."""
instance, mock_search_client, _ = azure_ai_search_instance
# 1. Insert a document
vector = [0.1, 0.2, 0.3]
payload = {"user_id": "user1", "content": "Initial content"}
doc_id = "workflow_doc"
mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}]
instance.insert([vector], [payload], [doc_id])
# 2. Search for the document
mock_search_client.search.return_value = [
{
"id": doc_id,
"@search.score": 0.95,
"payload": json.dumps(payload)
}
]
results = instance.search(vector, filters={"user_id": "user1"})
assert len(results) == 1
assert results[0].id == doc_id
# 3. Update the document
updated_payload = {"user_id": "user1", "content": "Updated content"}
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}]
instance.update(doc_id, payload=updated_payload)
# 4. Search again to get updated document
mock_search_client.search.return_value = [
{
"id": doc_id,
"@search.score": 0.95,
"payload": json.dumps(updated_payload)
}
]
results = instance.search(vector, filters={"user_id": "user1"})
assert len(results) == 1
assert results[0].id == doc_id
assert results[0].payload["content"] == "Updated content"
def test_sanitize_special_characters(azure_ai_search_instance):
"""Test that special characters in filter values are properly sanitized."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock response
mock_search_client.search.return_value = [
{
"id": "doc1",
"@search.score": 0.95,
"payload": json.dumps({"user_id": "user's-data"})
}
]
# Search with a filter that has special characters (common in real-world data)
filters = {"user_id": "user's-data"}
results = instance.search([0.1, 0.2, 0.3], filters=filters)
# Verify search was called with properly escaped filter
mock_search_client.search.assert_called_once()
_, kwargs = mock_search_client.search.call_args
assert "filter" in kwargs
# The filter should have properly escaped single quotes
filter_expr = kwargs["filter"]
assert "user_id eq 'user''s-data'" in filter_expr
def test_list_collections(azure_ai_search_instance):
"""Test listing all collections/indexes (for management interfaces)."""
instance, _, mock_index_client = azure_ai_search_instance
# List the collections
collections = instance.list_cols()
# Verify the correct method was called
mock_index_client.list_index_names.assert_called_once()
# Check the result
assert collections == ["test-index"]
def test_filter_with_numeric_values(azure_ai_search_instance):
"""Test filtering with numeric values (common for faceted search)."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock response
mock_search_client.search.return_value = [
{
"id": "doc1",
"@search.score": 0.95,
"payload": json.dumps({"user_id": "user1", "count": 42})
}
]
# Search with a numeric filter
# Note: In the actual implementation, numeric fields might need to be in the payload
filters = {"count": 42}
results = instance.search([0.1, 0.2, 0.3], filters=filters)
# Verify the filter expression
mock_search_client.search.assert_called_once()
_, kwargs = mock_search_client.search.call_args
filter_expr = kwargs["filter"]
assert "count eq 42" in filter_expr # No quotes for numbers
def test_error_on_update_nonexistent(azure_ai_search_instance):
"""Test behavior when updating a document that doesn't exist."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to return a failure for the update
mock_search_client.merge_or_upload_documents.return_value = [
{"status": False, "id": "nonexistent", "errorMessage": "Document not found"}
]
# Attempt to update a non-existent document
with pytest.raises(Exception) as exc_info:
instance.update("nonexistent", payload={"new": "data"})
assert "Update failed" in str(exc_info.value)
def test_different_compression_types():
"""Test creating instances with different compression types (important for performance tuning)."""
with patch("mem0.vector_stores.azure_ai_search.SearchClient"), \
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
# Test with scalar compression
scalar_instance = AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=3,
compression_type="scalar",
use_float16=False
)
# Test with no compression
no_compression_instance = AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=3,
compression_type=None,
use_float16=False
)
# No assertions needed - we're just verifying that initialization doesn't fail
def test_high_dimensional_vectors():
"""Test handling of high-dimensional vectors typical in AI embeddings."""
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
# Configure the mock client
mock_search_client = MockSearchClient.return_value
mock_search_client.upload_documents = Mock()
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
# Create an instance with higher dimensions like those from embedding models
high_dim_instance = AzureAISearch(
service_name="test-service",
collection_name="high-dim-index",
api_key="test-api-key",
embedding_model_dims=1536, # Common for models like OpenAI's embeddings
compression_type="binary", # Compression often used with high-dim vectors
use_float16=True # Reduced precision often used for memory efficiency
)
# Create a high-dimensional vector (stub with zeros for testing)
high_dim_vector = [0.0] * 1536
payload = {"user_id": "user1"}
doc_id = "high_dim_doc"
# Insert the document
high_dim_instance.insert([high_dim_vector], [payload], [doc_id])
# Verify the insert was called with the full vector
mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args
documents = args[0]
assert len(documents[0]["vector"]) == 1536