diff --git a/docs/components/vectordbs/dbs/azure_ai_search.mdx b/docs/components/vectordbs/dbs/azure_ai_search.mdx index b109759c..9a02e8cc 100644 --- a/docs/components/vectordbs/dbs/azure_ai_search.mdx +++ b/docs/components/vectordbs/dbs/azure_ai_search.mdx @@ -1,12 +1,14 @@ -[Azure AI Search](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications. +# Azure AI Search -### Usage +[Azure AI Search](https://learn.microsoft.com/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications. + +## Usage ```python import os from mem0 import Memory -os.environ["OPENAI_API_KEY"] = "sk-xx" #this key is used for embedding purpose +os.environ["OPENAI_API_KEY"] = "sk-xx" # This key is used for embedding purpose config = { "vector_store": { @@ -15,8 +17,8 @@ config = { "service_name": "ai-search-test", "api_key": "*****", "collection_name": "mem0", - "embedding_model_dims": 1536 , - "use_compression": False + "embedding_model_dims": 1536, + "compression_type": "none" } } } @@ -25,20 +27,61 @@ m = Memory.from_config(config) messages = [ {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."}, - {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} ] m.add(messages, user_id="alice", metadata={"category": "movies"}) ``` -### Config +## Advanced Usage -Let's see the available parameters for the `qdrant` config: -service_name (str): Azure Cognitive Search service name. 
-| Parameter | Description | Default Value | -| --- | --- | --- | -| `service_name` | Azure AI Search service name | `None` | -| `api_key` | API key of the Azure AI Search service | `None` | -| `collection_name` | The name of the collection/index to store the vectors, it will be created automatically if not exist | `mem0` | -| `embedding_model_dims` | Dimensions of the embedding model | `1536` | -| `use_compression` | Use scalar quantization vector compression | False | \ No newline at end of file +```python +# Search with specific filter mode +result = m.search( + "sci-fi movies", + filters={"user_id": "alice"}, + limit=5, + vector_filter_mode="preFilter" # Apply filters before vector search +) + +# Using binary compression for large vector collections +config = { + "vector_store": { + "provider": "azure_ai_search", + "config": { + "service_name": "ai-search-test", + "api_key": "*****", + "collection_name": "mem0", + "embedding_model_dims": 1536, + "compression_type": "binary", + "use_float16": True # Use half precision for storage efficiency + } + } +} +``` + +## Configuration Parameters + +| Parameter | Description | Default Value | Options | +| --- | --- | --- | --- | +| `service_name` | Azure AI Search service name | Required | - | +| `api_key` | API key of the Azure AI Search service | Required | - | +| `collection_name` | The name of the collection/index to store vectors | `mem0` | Any valid index name | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value | +| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` | +| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` | + +## Notes on Configuration Options + +- **compression_type**: + - `none`: No compression, uses full vector precision + - `scalar`: Scalar quantization with reasonable balance of speed and accuracy + - `binary`: Binary quantization for maximum compression with some accuracy trade-off + +- **vector_filter_mode**: + - `preFilter`: Applies filters before vector search (faster) + - `postFilter`: Applies filters after vector search (may provide better relevance) + +- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections. + +- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering. \ No newline at end of file diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index 632aef59..03a47633 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -1,5 +1,6 @@ import json import logging +import re from typing import List, Optional from pydantic import BaseModel @@ -12,6 +13,7 @@ try: from azure.search.documents import SearchClient from azure.search.documents.indexes import SearchIndexClient from azure.search.documents.indexes.models import ( + BinaryQuantizationCompression, HnswAlgorithmConfiguration, ScalarQuantizationCompression, SearchField, @@ -24,7 +26,7 @@ try: from azure.search.documents.models import VectorizedQuery except ImportError: raise ImportError( - "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'." + "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'." 
) logger = logging.getLogger(__name__) @@ -37,43 +39,82 @@ class OutputData(BaseModel): class AzureAISearch(VectorStoreBase): - def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression): - """Initialize the Azure Cognitive Search vector store. + def __init__( + self, + service_name, + collection_name, + api_key, + embedding_model_dims, + compression_type: Optional[str] = None, + use_float16: bool = False, + ): + """ + Initialize the Azure AI Search vector store. Args: - service_name (str): Azure Cognitive Search service name. + service_name (str): Azure AI Search service name. collection_name (str): Index name. - api_key (str): API key for the Azure Cognitive Search service. + api_key (str): API key for the Azure AI Search service. embedding_model_dims (int): Dimension of the embedding vector. - use_compression (bool): Use scalar quantization vector compression + compression_type (Optional[str]): Specifies the type of quantization to use. + Allowed values are None (no quantization), "scalar", or "binary". + use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single). + (Note: This flag is preserved from the initial implementation per feedback.) """ self.index_name = collection_name self.collection_name = collection_name self.embedding_model_dims = embedding_model_dims - self.use_compression = use_compression + # If compression_type is None, treat it as "none". + self.compression_type = (compression_type or "none").lower() + self.use_float16 = use_float16 + self.search_client = SearchClient( endpoint=f"https://{service_name}.search.windows.net", index_name=self.index_name, credential=AzureKeyCredential(api_key), ) self.index_client = SearchIndexClient( - endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key) + endpoint=f"https://{service_name}.search.windows.net", + credential=AzureKeyCredential(api_key), ) + + self.search_client._client._config.user_agent_policy.add_user_agent("mem0") + self.index_client._client._config.user_agent_policy.add_user_agent("mem0") + self.create_col() # create the collection / index def create_col(self): - """Create a new index in Azure Cognitive Search.""" - vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector - - if self.use_compression: + """Create a new index in Azure AI Search.""" + # Determine vector type based on use_float16 setting. + if self.use_float16: vector_type = "Collection(Edm.Half)" - compression_name = "myCompression" - compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)] else: vector_type = "Collection(Edm.Single)" - compression_name = None - compression_configurations = [] + # Configure compression settings based on the specified compression_type. + compression_configurations = [] + compression_name = None + if self.compression_type == "scalar": + compression_name = "myCompression" + # For SQ, rescoring defaults to True and oversampling defaults to 4. + compression_configurations = [ + ScalarQuantizationCompression( + compression_name=compression_name + # rescoring defaults to True and oversampling defaults to 4 + ) + ] + elif self.compression_type == "binary": + compression_name = "myCompression" + # For BQ, rescoring defaults to True and oversampling defaults to 10. 
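+            # Binary quantization gives the highest compression ratio and generally
+            # suits high-dimensional embeddings; the default rescoring step recovers
+            # most of the ranking accuracy lost to quantization.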
+ compression_configurations = [ + BinaryQuantizationCompression( + compression_name=compression_name + # rescoring defaults to True and oversampling defaults to 10 + ) + ] + # If no compression is desired, compression_configurations remains empty. + + fields = [ SimpleField(name="id", type=SearchFieldDataType.String, key=True), SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True), @@ -82,8 +123,8 @@ class AzureAISearch(VectorStoreBase): SearchField( name="vector", type=vector_type, - searchable=True, - vector_search_dimensions=vector_dimensions, + searchable=True, + vector_search_dimensions=self.embedding_model_dims, vector_search_profile_name="my-vector-config", ), SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True), @@ -91,7 +132,11 @@ class AzureAISearch(VectorStoreBase): vector_search = VectorSearch( profiles=[ - VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config") + VectorSearchProfile( + name="my-vector-config", + algorithm_configuration_name="my-algorithms-config", + compression_name=compression_name if self.compression_type != "none" else None + ) ], algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")], compressions=compression_configurations, @@ -101,14 +146,16 @@ class AzureAISearch(VectorStoreBase): def _generate_document(self, vector, payload, id): document = {"id": id, "vector": vector, "payload": json.dumps(payload)} - # Extract additional fields if they exist + # Extract additional fields if they exist. for field in ["user_id", "run_id", "agent_id"]: if field in payload: document[field] = payload[field] return document + # Note: Explicit "insert" calls may later be decoupled from memory management decisions. def insert(self, vectors, payloads=None, ids=None): - """Insert vectors into the index. + """ + Insert vectors into the index. Args: vectors (List[List[float]]): List of vectors to insert. @@ -116,61 +163,87 @@ class AzureAISearch(VectorStoreBase): ids (List[str], optional): List of IDs corresponding to vectors. """ logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}") - documents = [ - self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads) + self._generate_document(vector, payload, id) + for id, vector, payload in zip(ids, vectors, payloads) ] - self.search_client.upload_documents(documents) + response = self.search_client.upload_documents(documents) + for doc in response: + if not doc.get("status", False): + raise Exception(f"Insert failed for document {doc.get('id')}: {doc}") + return response + + def _sanitize_key(self, key: str) -> str: + return re.sub(r"[^\w]", "", key) def _build_filter_expression(self, filters): filter_conditions = [] for key, value in filters.items(): - # If the value is a string, add quotes + safe_key = self._sanitize_key(key) if isinstance(value, str): - condition = f"{key} eq '{value}'" + safe_value = value.replace("'", "''") + condition = f"{safe_key} eq '{safe_value}'" else: - condition = f"{key} eq {value}" + condition = f"{safe_key} eq {value}" filter_conditions.append(condition) - # Use 'and' to join multiple conditions filter_expression = " and ".join(filter_conditions) return filter_expression - def search(self, query, limit=5, filters=None): - """Search for similar vectors. + def search(self, query, limit=5, filters=None, vector_filter_mode="preFilter"): + """ + Search for similar vectors. Args: - query (List[float]): Query vectors. 
+ query (List[float]): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (Dict, optional): Filters to apply to the search. Defaults to None. + vector_filter_mode (str): Determines whether filters are applied before or after the vector search. + Known values: "preFilter" (default) and "postFilter". Returns: - list: Search results. + List[OutputData]: Search results. """ - # Build filter expression filter_expression = None if filters: filter_expression = self._build_filter_expression(filters) - vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector") - search_results = self.search_client.search(vector_queries=[vector_query], filter=filter_expression, top=limit) + vector_query = VectorizedQuery( + vector=query, k_nearest_neighbors=limit, fields="vector" + ) + search_results = self.search_client.search( + vector_queries=[vector_query], + filter=filter_expression, + top=limit, + vector_filter_mode=vector_filter_mode, + ) results = [] for result in search_results: payload = json.loads(result["payload"]) - results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) + results.append( + OutputData( + id=result["id"], score=result["@search.score"], payload=payload + ) + ) return results def delete(self, vector_id): - """Delete a vector by ID. + """ + Delete a vector by ID. Args: vector_id (str): ID of the vector to delete. """ - self.search_client.delete_documents(documents=[{"id": vector_id}]) + response = self.search_client.delete_documents(documents=[{"id": vector_id}]) + for doc in response: + if not doc.get("status", False): + raise Exception(f"Delete failed for document {vector_id}: {doc}") logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.") + return response def update(self, vector_id, vector=None, payload=None): - """Update a vector and its payload. + """ + Update a vector and its payload. Args: vector_id (str): ID of the vector to update. @@ -185,10 +258,15 @@ class AzureAISearch(VectorStoreBase): document["payload"] = json_payload for field in ["user_id", "run_id", "agent_id"]: document[field] = payload.get(field) - self.search_client.merge_or_upload_documents(documents=[document]) + response = self.search_client.merge_or_upload_documents(documents=[document]) + for doc in response: + if not doc.get("status", False): + raise Exception(f"Update failed for document {vector_id}: {doc}") + return response def get(self, vector_id) -> OutputData: - """Retrieve a vector by ID. + """ + Retrieve a vector by ID. Args: vector_id (str): ID of the vector to retrieve. @@ -200,35 +278,43 @@ class AzureAISearch(VectorStoreBase): result = self.search_client.get_document(key=vector_id) except ResourceNotFoundError: return None - return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) + return OutputData( + id=result["id"], score=None, payload=json.loads(result["payload"]) + ) def list_cols(self) -> List[str]: - """List all collections (indexes). + """ + List all collections (indexes). Returns: List[str]: List of index names. """ - indexes = self.index_client.list_indexes() - return [index.name for index in indexes] + try: + names = self.index_client.list_index_names() + except AttributeError: + names = [index.name for index in self.index_client.list_indexes()] + return names def delete_col(self): """Delete the index.""" self.index_client.delete_index(self.index_name) def col_info(self): - """Get information about the index. 
+ """ + Get information about the index. Returns: - Dict[str, Any]: Index information. + dict: Index information. """ index = self.index_client.get_index(self.index_name) return {"name": index.name, "fields": index.fields} def list(self, filters=None, limit=100): - """List all vectors in the index. + """ + List all vectors in the index. Args: - filters (Dict, optional): Filters to apply to the list. + filters (dict, optional): Filters to apply to the list. limit (int, optional): Number of vectors to return. Defaults to 100. Returns: @@ -238,13 +324,18 @@ class AzureAISearch(VectorStoreBase): if filters: filter_expression = self._build_filter_expression(filters) - search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit) + search_results = self.search_client.search( + search_text="*", filter=filter_expression, top=limit + ) results = [] for result in search_results: payload = json.loads(result["payload"]) - results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) - - return [results] + results.append( + OutputData( + id=result["id"], score=result["@search.score"], payload=payload + ) + ) + return results def __del__(self): """Close the search client when the object is deleted.""" diff --git a/poetry.lock b/poetry.lock index 731915c3..96c789b2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -112,7 +112,7 @@ propcache = ">=0.2.0" yarl = ">=1.17.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""] [[package]] name = "aiosignal" @@ -158,7 +158,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -184,12 +184,62 @@ files = [ ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; 
platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] + +[[package]] +name = "azure-common" +version = "1.1.28" +description = "Microsoft Azure Client Library for Python (Common)" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, + {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, +] + +[[package]] +name = "azure-core" +version = "1.32.0" +description = "Microsoft Azure Core Library for Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "azure_core-1.32.0-py3-none-any.whl", hash = "sha256:eac191a0efb23bfa83fddf321b27b122b4ec847befa3091fa736a5c32c50d7b4"}, + {file = "azure_core-1.32.0.tar.gz", hash = "sha256:22b3c35d6b2dae14990f6c1be2912bf23ffe50b220e708a28ab1bb92b1c730e5"}, +] + +[package.dependencies] +requests = ">=2.21.0" +six = ">=1.11.0" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["aiohttp (>=3.0)"] + +[[package]] +name = "azure-search-documents" +version = "11.5.2" +description = "Microsoft Azure Cognitive Search Client Library for Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "azure_search_documents-11.5.2-py3-none-any.whl", hash = "sha256:c949d011008a4b0bcee3db91132741b4e4d50ddb3f7e2f48944d949d4b413b11"}, + {file = "azure_search_documents-11.5.2.tar.gz", hash = "sha256:98977dd1fa4978d3b7d8891a0856b3becb6f02cc07ff2e1ea40b9c7254ada315"}, +] + +[package.dependencies] +azure-common = ">=1.1" +azure-core = ">=1.28.0" +isodate = ">=0.6.0" +typing-extensions = ">=4.6.0" [[package]] name = "backoff" @@ -868,7 +918,7 @@ httpcore = "==1.*" idna = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -910,6 +960,18 @@ files = [ {file = 
"iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "isodate" +version = "0.7.2" +description = "An ISO 8601 date/time/duration parser and formatter" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"}, + {file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"}, +] + [[package]] name = "isort" version = "5.13.2" @@ -1799,7 +1861,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -2213,13 +2275,13 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] -core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] +core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"] [[package]] name = "six" @@ -2449,7 +2511,7 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; 
platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] diff --git a/pyproject.toml b/pyproject.toml index d09844e3..cc11d9e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ sqlalchemy = "^2.0.31" langchain-neo4j = "^0.4.0" neo4j = "^5.23.1" rank-bm25 = "^0.2.2" +azure-search-documents = "^11.5.0" psycopg2-binary = "^2.9.10" [tool.poetry.extras] diff --git a/tests/vector_stores/test_azure_ai_search.py b/tests/vector_stores/test_azure_ai_search.py new file mode 100644 index 00000000..77e38a3b --- /dev/null +++ b/tests/vector_stores/test_azure_ai_search.py @@ -0,0 +1,508 @@ +import json +from unittest.mock import Mock, patch +import pytest +from azure.core.exceptions import ResourceNotFoundError, HttpResponseError + +# Import the AzureAISearch class and OutputData model from your module. +from mem0.vector_stores.azure_ai_search import AzureAISearch + + +# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch. +@pytest.fixture +def mock_clients(): + with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient: + # Create mocked instances for search and index clients. + mock_search_client = MockSearchClient.return_value + mock_index_client = MockIndexClient.return_value + + # Stub required methods on search_client. + mock_search_client.upload_documents = Mock() + mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}] + mock_search_client.search = Mock() + mock_search_client.delete_documents = Mock() + mock_search_client.delete_documents.return_value = [{"status": True, "id": "doc1"}] + mock_search_client.merge_or_upload_documents = Mock() + mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": "doc1"}] + mock_search_client.get_document = Mock() + mock_search_client.close = Mock() + + # Stub required methods on index_client. + mock_index_client.create_or_update_index = Mock() + mock_index_client.list_indexes = Mock(return_value=[]) + mock_index_client.list_index_names = Mock(return_value=["test-index"]) + mock_index_client.delete_index = Mock() + # For col_info() we assume get_index returns an object with name and fields attributes. + fake_index = Mock() + fake_index.name = "test-index" + fake_index.fields = ["id", "vector", "payload", "user_id", "run_id", "agent_id"] + mock_index_client.get_index = Mock(return_value=fake_index) + mock_index_client.close = Mock() + + yield mock_search_client, mock_index_client + +@pytest.fixture +def azure_ai_search_instance(mock_clients): + mock_search_client, mock_index_client = mock_clients + # Create an instance with dummy parameters. + instance = AzureAISearch( + service_name="test-service", + collection_name="test-index", + api_key="test-api-key", + embedding_model_dims=3, + compression_type="binary", # testing binary quantization option + use_float16=True + ) + # Return instance and clients for verification. + return instance, mock_search_client, mock_index_client + +# --- Original tests --- + +def test_create_col(azure_ai_search_instance): + instance, mock_search_client, mock_index_client = azure_ai_search_instance + # Upon initialization, create_col should be called. + mock_index_client.create_or_update_index.assert_called_once() + # Optionally, you could inspect the call arguments for vector type. 
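+
+def test_create_col_index_definition(azure_ai_search_instance):
+    """Sketch of the optional inspection mentioned above (assumes the SearchIndex
+    is passed to create_or_update_index as the first positional argument)."""
+    _, _, mock_index_client = azure_ai_search_instance
+    args, kwargs = mock_index_client.create_or_update_index.call_args
+    index = args[0] if args else next(iter(kwargs.values()))
+    vector_field = next(f for f in index.fields if f.name == "vector")
+    # use_float16=True should produce half-precision vector storage ...
+    assert vector_field.type == "Collection(Edm.Half)"
+    # ... and compression_type="binary" should register a compression configuration.
+    assert index.vector_search.compressions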
+ +def test_insert(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"user_id": "user1", "run_id": "run1"}] + ids = ["doc1"] + + instance.insert(vectors, payloads, ids) + + mock_search_client.upload_documents.assert_called_once() + args, _ = mock_search_client.upload_documents.call_args + documents = args[0] + # Update expected_doc to include extra fields from payload. + expected_doc = { + "id": "doc1", + "vector": [0.1, 0.2, 0.3], + "payload": json.dumps({"user_id": "user1", "run_id": "run1"}), + "user_id": "user1", + "run_id": "run1" + } + assert documents[0] == expected_doc + +def test_search_preFilter(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + # Setup a fake search result returned by the mocked search method. + fake_result = { + "id": "doc1", + "@search.score": 0.95, + "payload": json.dumps({"user_id": "user1"}) + } + # Configure the mock to return an iterator (list) with fake_result. + mock_search_client.search.return_value = [fake_result] + + query_vector = [0.1, 0.2, 0.3] + results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter") + + # Verify that the search method was called with vector_filter_mode="preFilter". + mock_search_client.search.assert_called_once() + _, called_kwargs = mock_search_client.search.call_args + assert called_kwargs.get("vector_filter_mode") == "preFilter" + + # Verify that the output is parsed correctly. + assert len(results) == 1 + assert results[0].id == "doc1" + assert results[0].score == 0.95 + assert results[0].payload == {"user_id": "user1"} + +def test_search_postFilter(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + # Setup a fake search result for postFilter. + fake_result = { + "id": "doc2", + "@search.score": 0.85, + "payload": json.dumps({"user_id": "user2"}) + } + mock_search_client.search.return_value = [fake_result] + + query_vector = [0.4, 0.5, 0.6] + results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter") + + mock_search_client.search.assert_called_once() + _, called_kwargs = mock_search_client.search.call_args + assert called_kwargs.get("vector_filter_mode") == "postFilter" + + assert len(results) == 1 + assert results[0].id == "doc2" + assert results[0].score == 0.85 + assert results[0].payload == {"user_id": "user2"} + +def test_delete(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + vector_id = "doc1" + # Set delete_documents to return an iterable with a successful response. + mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}] + instance.delete(vector_id) + mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}]) + +def test_update(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + vector_id = "doc1" + new_vector = [0.7, 0.8, 0.9] + new_payload = {"user_id": "updated"} + # Set merge_or_upload_documents to return an iterable with a successful response. 
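+    # (These dict-shaped results match the implementation's doc.get("status") check;
+    # note that the real SDK returns IndexingResult objects exposing `succeeded`
+    # rather than a dict interface, so this mock shape is a simplification.)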
+ mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}] + instance.update(vector_id, vector=new_vector, payload=new_payload) + mock_search_client.merge_or_upload_documents.assert_called_once() + kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs + document = kwargs["documents"][0] + assert document["id"] == vector_id + assert document["vector"] == new_vector + assert document["payload"] == json.dumps(new_payload) + # The update method will also add the 'user_id' field. + assert document["user_id"] == "updated" + +def test_get(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + fake_result = { + "id": "doc1", + "payload": json.dumps({"user_id": "user1"}) + } + mock_search_client.get_document.return_value = fake_result + result = instance.get("doc1") + mock_search_client.get_document.assert_called_once_with(key="doc1") + assert result.id == "doc1" + assert result.payload == {"user_id": "user1"} + assert result.score is None + +def test_list(azure_ai_search_instance): + instance, mock_search_client, _ = azure_ai_search_instance + fake_result = { + "id": "doc1", + "@search.score": 0.99, + "payload": json.dumps({"user_id": "user1"}) + } + mock_search_client.search.return_value = [fake_result] + # Call list with a simple filter. + results = instance.list(filters={"user_id": "user1"}, limit=1) + # Verify the search method was called with the proper parameters. + expected_filter = instance._build_filter_expression({"user_id": "user1"}) + mock_search_client.search.assert_called_once_with( + search_text="*", + filter=expected_filter, + top=1 + ) + assert isinstance(results, list) + assert len(results) == 1 + assert results[0].id == "doc1" + +# --- New tests for practical end-user scenarios --- + +def test_bulk_insert(azure_ai_search_instance): + """Test inserting a batch of documents (common for initial data loading).""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Create a batch of 10 documents + num_docs = 10 + vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)] + payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)] + ids = [f"doc{i}" for i in range(num_docs)] + + # Configure mock to return success for all documents + mock_search_client.upload_documents.return_value = [ + {"status": True, "id": id_val} for id_val in ids + ] + + # Insert the batch + instance.insert(vectors, payloads, ids) + + # Verify the call + mock_search_client.upload_documents.assert_called_once() + args, _ = mock_search_client.upload_documents.call_args + documents = args[0] + assert len(documents) == num_docs + + # Verify the first and last document + assert documents[0]["id"] == "doc0" + assert documents[-1]["id"] == f"doc{num_docs-1}" + + +def test_insert_error_handling(azure_ai_search_instance): + """Test how the class handles Azure errors during insertion.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to return a failure for one document + mock_search_client.upload_documents.return_value = [ + {"status": False, "id": "doc1", "errorMessage": "Azure error"} + ] + + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"user_id": "user1"}] + ids = ["doc1"] + + # Exception should be raised + with pytest.raises(Exception) as exc_info: + instance.insert(vectors, payloads, ids) + + assert "Insert failed" in str(exc_info.value) + + +def test_search_with_complex_filters(azure_ai_search_instance): + """Test searching with multiple filter 
conditions as a user might need.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock response + mock_search_client.search.return_value = [ + { + "id": "doc1", + "@search.score": 0.95, + "payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"}) + } + ] + + # Search with multiple filters (common in multi-tenant or segmented applications) + filters = { + "user_id": "user1", + "run_id": "run123", + "agent_id": "agent456" + } + results = instance.search([0.1, 0.2, 0.3], filters=filters) + + # Verify search was called with the correct filter expression + mock_search_client.search.assert_called_once() + _, kwargs = mock_search_client.search.call_args + assert "filter" in kwargs + + # The filter should contain all three conditions + filter_expr = kwargs["filter"] + assert "user_id eq 'user1'" in filter_expr + assert "run_id eq 'run123'" in filter_expr + assert "agent_id eq 'agent456'" in filter_expr + assert " and " in filter_expr # Conditions should be joined by AND + + +def test_empty_search_results(azure_ai_search_instance): + """Test behavior when search returns no results (common edge case).""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to return empty results + mock_search_client.search.return_value = [] + + # Search with a non-matching query + results = instance.search([0.9, 0.9, 0.9], limit=5) + + # Verify result handling + assert len(results) == 0 + + +def test_get_nonexistent_document(azure_ai_search_instance): + """Test behavior when getting a document that doesn't exist (should handle gracefully).""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to raise ResourceNotFoundError + mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found") + + # Get a non-existent document + result = instance.get("nonexistent_id") + + # Should return None instead of raising exception + assert result is None + + +def test_azure_service_error(azure_ai_search_instance): + """Test handling of Azure service errors (important for robustness).""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to raise HttpResponseError + http_error = HttpResponseError("Azure service is unavailable") + mock_search_client.search.side_effect = http_error + + # Attempt to search + with pytest.raises(HttpResponseError): + instance.search([0.1, 0.2, 0.3]) + + # Verify search was attempted + mock_search_client.search.assert_called_once() + + +def test_realistic_workflow(azure_ai_search_instance): + """Test a realistic workflow: insert → search → update → search again.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # 1. Insert a document + vector = [0.1, 0.2, 0.3] + payload = {"user_id": "user1", "content": "Initial content"} + doc_id = "workflow_doc" + + mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}] + instance.insert([vector], [payload], [doc_id]) + + # 2. Search for the document + mock_search_client.search.return_value = [ + { + "id": doc_id, + "@search.score": 0.95, + "payload": json.dumps(payload) + } + ] + results = instance.search(vector, filters={"user_id": "user1"}) + assert len(results) == 1 + assert results[0].id == doc_id + + # 3. 
Update the document + updated_payload = {"user_id": "user1", "content": "Updated content"} + mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}] + instance.update(doc_id, payload=updated_payload) + + # 4. Search again to get updated document + mock_search_client.search.return_value = [ + { + "id": doc_id, + "@search.score": 0.95, + "payload": json.dumps(updated_payload) + } + ] + results = instance.search(vector, filters={"user_id": "user1"}) + assert len(results) == 1 + assert results[0].id == doc_id + assert results[0].payload["content"] == "Updated content" + + +def test_sanitize_special_characters(azure_ai_search_instance): + """Test that special characters in filter values are properly sanitized.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock response + mock_search_client.search.return_value = [ + { + "id": "doc1", + "@search.score": 0.95, + "payload": json.dumps({"user_id": "user's-data"}) + } + ] + + # Search with a filter that has special characters (common in real-world data) + filters = {"user_id": "user's-data"} + results = instance.search([0.1, 0.2, 0.3], filters=filters) + + # Verify search was called with properly escaped filter + mock_search_client.search.assert_called_once() + _, kwargs = mock_search_client.search.call_args + assert "filter" in kwargs + + # The filter should have properly escaped single quotes + filter_expr = kwargs["filter"] + assert "user_id eq 'user''s-data'" in filter_expr + + +def test_list_collections(azure_ai_search_instance): + """Test listing all collections/indexes (for management interfaces).""" + instance, _, mock_index_client = azure_ai_search_instance + + # List the collections + collections = instance.list_cols() + + # Verify the correct method was called + mock_index_client.list_index_names.assert_called_once() + + # Check the result + assert collections == ["test-index"] + + +def test_filter_with_numeric_values(azure_ai_search_instance): + """Test filtering with numeric values (common for faceted search).""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock response + mock_search_client.search.return_value = [ + { + "id": "doc1", + "@search.score": 0.95, + "payload": json.dumps({"user_id": "user1", "count": 42}) + } + ] + + # Search with a numeric filter + # Note: In the actual implementation, numeric fields might need to be in the payload + filters = {"count": 42} + results = instance.search([0.1, 0.2, 0.3], filters=filters) + + # Verify the filter expression + mock_search_client.search.assert_called_once() + _, kwargs = mock_search_client.search.call_args + filter_expr = kwargs["filter"] + assert "count eq 42" in filter_expr # No quotes for numbers + + +def test_error_on_update_nonexistent(azure_ai_search_instance): + """Test behavior when updating a document that doesn't exist.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to return a failure for the update + mock_search_client.merge_or_upload_documents.return_value = [ + {"status": False, "id": "nonexistent", "errorMessage": "Document not found"} + ] + + # Attempt to update a non-existent document + with pytest.raises(Exception) as exc_info: + instance.update("nonexistent", payload={"new": "data"}) + + assert "Update failed" in str(exc_info.value) + + +def test_different_compression_types(): + """Test creating instances with different compression types (important for performance tuning).""" + with 
patch("mem0.vector_stores.azure_ai_search.SearchClient"), \ + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"): + + # Test with scalar compression + scalar_instance = AzureAISearch( + service_name="test-service", + collection_name="scalar-index", + api_key="test-api-key", + embedding_model_dims=3, + compression_type="scalar", + use_float16=False + ) + + # Test with no compression + no_compression_instance = AzureAISearch( + service_name="test-service", + collection_name="no-compression-index", + api_key="test-api-key", + embedding_model_dims=3, + compression_type=None, + use_float16=False + ) + + # No assertions needed - we're just verifying that initialization doesn't fail + + +def test_high_dimensional_vectors(): + """Test handling of high-dimensional vectors typical in AI embeddings.""" + with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"): + + # Configure the mock client + mock_search_client = MockSearchClient.return_value + mock_search_client.upload_documents = Mock() + mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}] + + # Create an instance with higher dimensions like those from embedding models + high_dim_instance = AzureAISearch( + service_name="test-service", + collection_name="high-dim-index", + api_key="test-api-key", + embedding_model_dims=1536, # Common for models like OpenAI's embeddings + compression_type="binary", # Compression often used with high-dim vectors + use_float16=True # Reduced precision often used for memory efficiency + ) + + # Create a high-dimensional vector (stub with zeros for testing) + high_dim_vector = [0.0] * 1536 + payload = {"user_id": "user1"} + doc_id = "high_dim_doc" + + # Insert the document + high_dim_instance.insert([high_dim_vector], [payload], [doc_id]) + + # Verify the insert was called with the full vector + mock_search_client.upload_documents.assert_called_once() + args, _ = mock_search_client.upload_documents.call_args + documents = args[0] + assert len(documents[0]["vector"]) == 1536 \ No newline at end of file