From e25dc4b5047a97acf8043c27e8e33e4621331183 Mon Sep 17 00:00:00 2001 From: Farzad Sunavala <40604067+farzad528@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:49:46 -0500 Subject: [PATCH] bugfix: update Azure AI Search Config (#2380) --- mem0/configs/vector_stores/azure_ai_search.py | 44 +- tests/vector_stores/test_azure_ai_search.py | 832 +++++++++--------- 2 files changed, 468 insertions(+), 408 deletions(-) diff --git a/mem0/configs/vector_stores/azure_ai_search.py b/mem0/configs/vector_stores/azure_ai_search.py index 5619b300..b256e139 100644 --- a/mem0/configs/vector_stores/azure_ai_search.py +++ b/mem0/configs/vector_stores/azure_ai_search.py @@ -1,27 +1,53 @@ -from typing import Any, Dict - +from typing import Any, Dict, Optional from pydantic import BaseModel, Field, model_validator class AzureAISearchConfig(BaseModel): collection_name: str = Field("mem0", description="Name of the collection") - service_name: str = Field(None, description="Azure Cognitive Search service name") - api_key: str = Field(None, description="API key for the Azure Cognitive Search service") + service_name: str = Field(None, description="Azure AI Search service name") + api_key: str = Field(None, description="API key for the Azure AI Search service") embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") - use_compression: bool = Field(False, description="Whether to use scalar quantization vector compression.") - + compression_type: Optional[str] = Field( + None, + description="Type of vector compression to use. Options: 'scalar', 'binary', or None" + ) + use_float16: bool = Field( + False, + description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)" + ) + @model_validator(mode="before") @classmethod def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: allowed_fields = set(cls.model_fields.keys()) input_fields = set(values.keys()) extra_fields = input_fields - allowed_fields + + # Check for use_compression to provide a helpful error + if "use_compression" in extra_fields: + raise ValueError( + "The parameter 'use_compression' is no longer supported. " + "Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' " + "or 'compression_type=None' instead of 'use_compression=False'." + ) + if extra_fields: raise ValueError( - f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + f"Extra fields not allowed: {', '.join(extra_fields)}. " + f"Please input only the following fields: {', '.join(allowed_fields)}" ) + + # Validate compression_type values + if "compression_type" in values and values["compression_type"] is not None: + valid_types = ["scalar", "binary"] + if values["compression_type"].lower() not in valid_types: + raise ValueError( + f"Invalid compression_type: {values['compression_type']}. " + f"Must be one of: {', '.join(valid_types)}, or None" + ) + return values - + model_config = { "arbitrary_types_allowed": True, - } + } \ No newline at end of file diff --git a/tests/vector_stores/test_azure_ai_search.py b/tests/vector_stores/test_azure_ai_search.py index 77e38a3b..ac1106c5 100644 --- a/tests/vector_stores/test_azure_ai_search.py +++ b/tests/vector_stores/test_azure_ai_search.py @@ -1,20 +1,28 @@ import json -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock, call import pytest from azure.core.exceptions import ResourceNotFoundError, HttpResponseError -# Import the AzureAISearch class and OutputData model from your module. -from mem0.vector_stores.azure_ai_search import AzureAISearch +# Import the AzureAISearch class and related models +from mem0.vector_stores.azure_ai_search import AzureAISearch, OutputData +from mem0.configs.vector_stores.azure_ai_search import AzureAISearchConfig # Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch. @pytest.fixture def mock_clients(): with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ - patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient: + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \ + patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential: # Create mocked instances for search and index clients. mock_search_client = MockSearchClient.return_value mock_index_client = MockIndexClient.return_value + + # Mock the client._client._config.user_agent_policy.add_user_agent + mock_search_client._client = MagicMock() + mock_search_client._client._config.user_agent_policy.add_user_agent = Mock() + mock_index_client._client = MagicMock() + mock_index_client._client._config.user_agent_policy.add_user_agent = Mock() # Stub required methods on search_client. mock_search_client.upload_documents = Mock() @@ -29,7 +37,7 @@ def mock_clients(): # Stub required methods on index_client. mock_index_client.create_or_update_index = Mock() - mock_index_client.list_indexes = Mock(return_value=[]) + mock_index_client.list_indexes = Mock() mock_index_client.list_index_names = Mock(return_value=["test-index"]) mock_index_client.delete_index = Mock() # For col_info() we assume get_index returns an object with name and fields attributes. @@ -39,11 +47,12 @@ def mock_clients(): mock_index_client.get_index = Mock(return_value=fake_index) mock_index_client.close = Mock() - yield mock_search_client, mock_index_client + yield mock_search_client, mock_index_client, MockAzureKeyCredential + @pytest.fixture def azure_ai_search_instance(mock_clients): - mock_search_client, mock_index_client = mock_clients + mock_search_client, mock_index_client, _ = mock_clients # Create an instance with dummy parameters. instance = AzureAISearch( service_name="test-service", @@ -56,150 +65,334 @@ def azure_ai_search_instance(mock_clients): # Return instance and clients for verification. return instance, mock_search_client, mock_index_client -# --- Original tests --- + +# --- Tests for AzureAISearchConfig --- + +def test_config_validation_valid(): + """Test valid configurations are accepted.""" + # Test minimal configuration + config = AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768 + ) + assert config.collection_name == "mem0" # Default value + assert config.service_name == "test-service" + assert config.api_key == "test-api-key" + assert config.embedding_model_dims == 768 + assert config.compression_type is None + assert config.use_float16 is False + + # Test with all optional parameters + config = AzureAISearchConfig( + collection_name="custom-index", + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=1536, + compression_type="scalar", + use_float16=True + ) + assert config.collection_name == "custom-index" + assert config.compression_type == "scalar" + assert config.use_float16 is True + + +def test_config_validation_invalid_compression_type(): + """Test that invalid compression types are rejected.""" + with pytest.raises(ValueError) as exc_info: + AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="invalid-type" # Not a valid option + ) + assert "Invalid compression_type" in str(exc_info.value) + + +def test_config_validation_deprecated_use_compression(): + """Test that using the deprecated use_compression parameter raises an error.""" + with pytest.raises(ValueError) as exc_info: + AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768, + use_compression=True # Deprecated parameter + ) + # Fix: Use a partial string match instead of exact match + assert "use_compression" in str(exc_info.value) + assert "no longer supported" in str(exc_info.value) + + +def test_config_validation_extra_fields(): + """Test that extra fields are rejected.""" + with pytest.raises(ValueError) as exc_info: + AzureAISearchConfig( + service_name="test-service", + api_key="test-api-key", + embedding_model_dims=768, + unknown_parameter="value" # Extra field + ) + assert "Extra fields not allowed" in str(exc_info.value) + assert "unknown_parameter" in str(exc_info.value) + + +# --- Tests for AzureAISearch initialization --- + +def test_initialization(mock_clients): + """Test AzureAISearch initialization with different parameters.""" + mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients + + # Test with minimal parameters + instance = AzureAISearch( + service_name="test-service", + collection_name="test-index", + api_key="test-api-key", + embedding_model_dims=768 + ) + + # Verify initialization parameters + assert instance.index_name == "test-index" + assert instance.collection_name == "test-index" + assert instance.embedding_model_dims == 768 + assert instance.compression_type == "none" # Default when None is passed + assert instance.use_float16 is False + + # Verify client creation + mock_azure_key_credential.assert_called_with("test-api-key") + assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0] + assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0] + + # Verify index creation was called + mock_index_client.create_or_update_index.assert_called_once() + + +def test_initialization_with_compression_types(mock_clients): + """Test initialization with different compression types.""" + mock_search_client, mock_index_client, _ = mock_clients + + # Test with scalar compression + instance = AzureAISearch( + service_name="test-service", + collection_name="scalar-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="scalar" + ) + assert instance.compression_type == "scalar" + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Verify scalar compression was configured + assert hasattr(index.vector_search, 'compressions') + assert len(index.vector_search.compressions) > 0 + assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) + + # Test with binary compression + instance = AzureAISearch( + service_name="test-service", + collection_name="binary-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="binary" + ) + assert instance.compression_type == "binary" + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Verify binary compression was configured + assert hasattr(index.vector_search, 'compressions') + assert len(index.vector_search.compressions) > 0 + assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0])) + + # Test with no compression + instance = AzureAISearch( + service_name="test-service", + collection_name="no-compression-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type=None + ) + assert instance.compression_type == "none" + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Verify no compression was configured + assert hasattr(index.vector_search, 'compressions') + assert len(index.vector_search.compressions) == 0 + + +def test_initialization_with_float_precision(mock_clients): + """Test initialization with different float precision settings.""" + mock_search_client, mock_index_client, _ = mock_clients + + # Test with half precision (float16) + instance = AzureAISearch( + service_name="test-service", + collection_name="float16-index", + api_key="test-api-key", + embedding_model_dims=768, + use_float16=True + ) + assert instance.use_float16 is True + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Find the vector field and check its type + vector_field = next((f for f in index.fields if f.name == "vector"), None) + assert vector_field is not None + assert "Edm.Half" in vector_field.type + + # Test with full precision (float32) + instance = AzureAISearch( + service_name="test-service", + collection_name="float32-index", + api_key="test-api-key", + embedding_model_dims=768, + use_float16=False + ) + assert instance.use_float16 is False + + # Capture the index creation call + args, _ = mock_index_client.create_or_update_index.call_args_list[-1] + index = args[0] + # Find the vector field and check its type + vector_field = next((f for f in index.fields if f.name == "vector"), None) + assert vector_field is not None + assert "Edm.Single" in vector_field.type + + +# --- Tests for create_col method --- def test_create_col(azure_ai_search_instance): - instance, mock_search_client, mock_index_client = azure_ai_search_instance - # Upon initialization, create_col should be called. + """Test the create_col method creates an index with the correct configuration.""" + instance, _, mock_index_client = azure_ai_search_instance + + # create_col is called during initialization, so we check the call that was already made mock_index_client.create_or_update_index.assert_called_once() - # Optionally, you could inspect the call arguments for vector type. + + # Verify the index configuration + args, _ = mock_index_client.create_or_update_index.call_args + index = args[0] + + # Check basic properties + assert index.name == "test-index" + assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload + + # Check that required fields are present + field_names = [f.name for f in index.fields] + assert "id" in field_names + assert "vector" in field_names + assert "payload" in field_names + assert "user_id" in field_names + assert "run_id" in field_names + assert "agent_id" in field_names + + # Check that id is the key field + id_field = next(f for f in index.fields if f.name == "id") + assert id_field.key is True + + # Check vector search configuration + assert index.vector_search is not None + assert len(index.vector_search.profiles) == 1 + assert index.vector_search.profiles[0].name == "my-vector-config" + assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config" + + # Check algorithms + assert len(index.vector_search.algorithms) == 1 + assert index.vector_search.algorithms[0].name == "my-algorithms-config" + assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0])) + + # With binary compression and float16, we should have compression configuration + assert len(index.vector_search.compressions) == 1 + assert index.vector_search.compressions[0].compression_name == "myCompression" + assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0])) -def test_insert(azure_ai_search_instance): + +def test_create_col_scalar_compression(mock_clients): + """Test creating a collection with scalar compression.""" + mock_search_client, mock_index_client, _ = mock_clients + + instance = AzureAISearch( + service_name="test-service", + collection_name="scalar-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type="scalar" + ) + + # Verify the index configuration + args, _ = mock_index_client.create_or_update_index.call_args + index = args[0] + + # Check compression configuration + assert len(index.vector_search.compressions) == 1 + assert index.vector_search.compressions[0].compression_name == "myCompression" + assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) + + # Check profile references compression + assert index.vector_search.profiles[0].compression_name == "myCompression" + + +def test_create_col_no_compression(mock_clients): + """Test creating a collection with no compression.""" + mock_search_client, mock_index_client, _ = mock_clients + + instance = AzureAISearch( + service_name="test-service", + collection_name="no-compression-index", + api_key="test-api-key", + embedding_model_dims=768, + compression_type=None + ) + + # Verify the index configuration + args, _ = mock_index_client.create_or_update_index.call_args + index = args[0] + + # Check compression configuration - should be empty + assert len(index.vector_search.compressions) == 0 + + # Check profile doesn't reference compression + assert index.vector_search.profiles[0].compression_name is None + + +# --- Tests for insert method --- + +def test_insert_single(azure_ai_search_instance): + """Test inserting a single vector.""" instance, mock_search_client, _ = azure_ai_search_instance vectors = [[0.1, 0.2, 0.3]] - payloads = [{"user_id": "user1", "run_id": "run1"}] + payloads = [{"user_id": "user1", "run_id": "run1", "agent_id": "agent1"}] ids = ["doc1"] instance.insert(vectors, payloads, ids) + # Verify upload_documents was called correctly mock_search_client.upload_documents.assert_called_once() args, _ = mock_search_client.upload_documents.call_args documents = args[0] - # Update expected_doc to include extra fields from payload. - expected_doc = { - "id": "doc1", - "vector": [0.1, 0.2, 0.3], - "payload": json.dumps({"user_id": "user1", "run_id": "run1"}), - "user_id": "user1", - "run_id": "run1" - } - assert documents[0] == expected_doc + + # Verify document structure + assert len(documents) == 1 + assert documents[0]["id"] == "doc1" + assert documents[0]["vector"] == [0.1, 0.2, 0.3] + assert documents[0]["payload"] == json.dumps(payloads[0]) + assert documents[0]["user_id"] == "user1" + assert documents[0]["run_id"] == "run1" + assert documents[0]["agent_id"] == "agent1" -def test_search_preFilter(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - # Setup a fake search result returned by the mocked search method. - fake_result = { - "id": "doc1", - "@search.score": 0.95, - "payload": json.dumps({"user_id": "user1"}) - } - # Configure the mock to return an iterator (list) with fake_result. - mock_search_client.search.return_value = [fake_result] - query_vector = [0.1, 0.2, 0.3] - results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter") - - # Verify that the search method was called with vector_filter_mode="preFilter". - mock_search_client.search.assert_called_once() - _, called_kwargs = mock_search_client.search.call_args - assert called_kwargs.get("vector_filter_mode") == "preFilter" - - # Verify that the output is parsed correctly. - assert len(results) == 1 - assert results[0].id == "doc1" - assert results[0].score == 0.95 - assert results[0].payload == {"user_id": "user1"} - -def test_search_postFilter(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - # Setup a fake search result for postFilter. - fake_result = { - "id": "doc2", - "@search.score": 0.85, - "payload": json.dumps({"user_id": "user2"}) - } - mock_search_client.search.return_value = [fake_result] - - query_vector = [0.4, 0.5, 0.6] - results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter") - - mock_search_client.search.assert_called_once() - _, called_kwargs = mock_search_client.search.call_args - assert called_kwargs.get("vector_filter_mode") == "postFilter" - - assert len(results) == 1 - assert results[0].id == "doc2" - assert results[0].score == 0.85 - assert results[0].payload == {"user_id": "user2"} - -def test_delete(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - vector_id = "doc1" - # Set delete_documents to return an iterable with a successful response. - mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}] - instance.delete(vector_id) - mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}]) - -def test_update(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - vector_id = "doc1" - new_vector = [0.7, 0.8, 0.9] - new_payload = {"user_id": "updated"} - # Set merge_or_upload_documents to return an iterable with a successful response. - mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}] - instance.update(vector_id, vector=new_vector, payload=new_payload) - mock_search_client.merge_or_upload_documents.assert_called_once() - kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs - document = kwargs["documents"][0] - assert document["id"] == vector_id - assert document["vector"] == new_vector - assert document["payload"] == json.dumps(new_payload) - # The update method will also add the 'user_id' field. - assert document["user_id"] == "updated" - -def test_get(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - fake_result = { - "id": "doc1", - "payload": json.dumps({"user_id": "user1"}) - } - mock_search_client.get_document.return_value = fake_result - result = instance.get("doc1") - mock_search_client.get_document.assert_called_once_with(key="doc1") - assert result.id == "doc1" - assert result.payload == {"user_id": "user1"} - assert result.score is None - -def test_list(azure_ai_search_instance): - instance, mock_search_client, _ = azure_ai_search_instance - fake_result = { - "id": "doc1", - "@search.score": 0.99, - "payload": json.dumps({"user_id": "user1"}) - } - mock_search_client.search.return_value = [fake_result] - # Call list with a simple filter. - results = instance.list(filters={"user_id": "user1"}, limit=1) - # Verify the search method was called with the proper parameters. - expected_filter = instance._build_filter_expression({"user_id": "user1"}) - mock_search_client.search.assert_called_once_with( - search_text="*", - filter=expected_filter, - top=1 - ) - assert isinstance(results, list) - assert len(results) == 1 - assert results[0].id == "doc1" - -# --- New tests for practical end-user scenarios --- - -def test_bulk_insert(azure_ai_search_instance): - """Test inserting a batch of documents (common for initial data loading).""" +def test_insert_multiple(azure_ai_search_instance): + """Test inserting multiple vectors in one call.""" instance, mock_search_client, _ = azure_ai_search_instance - # Create a batch of 10 documents - num_docs = 10 - vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)] + # Create multiple vectors + num_docs = 3 + vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)] payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)] ids = [f"doc{i}" for i in range(num_docs)] @@ -208,25 +401,35 @@ def test_bulk_insert(azure_ai_search_instance): {"status": True, "id": id_val} for id_val in ids ] - # Insert the batch + # Insert the documents instance.insert(vectors, payloads, ids) - # Verify the call + # Verify upload_documents was called with correct documents mock_search_client.upload_documents.assert_called_once() args, _ = mock_search_client.upload_documents.call_args documents = args[0] + + # Verify all documents were included assert len(documents) == num_docs - # Verify the first and last document + # Check first document assert documents[0]["id"] == "doc0" - assert documents[-1]["id"] == f"doc{num_docs-1}" + assert documents[0]["vector"] == [0.0, 0.1, 0.2] + assert documents[0]["payload"] == json.dumps(payloads[0]) + assert documents[0]["user_id"] == "user0" + + # Check last document + assert documents[2]["id"] == "doc2" + assert documents[2]["vector"] == [0.2, 0.3, 0.4] + assert documents[2]["payload"] == json.dumps(payloads[2]) + assert documents[2]["user_id"] == "user2" -def test_insert_error_handling(azure_ai_search_instance): - """Test how the class handles Azure errors during insertion.""" +def test_insert_with_error(azure_ai_search_instance): + """Test insert when Azure returns an error for one or more documents.""" instance, mock_search_client, _ = azure_ai_search_instance - # Configure mock to return a failure for one document + # Configure mock to return an error for one document mock_search_client.upload_documents.return_value = [ {"status": False, "id": "doc1", "errorMessage": "Azure error"} ] @@ -235,274 +438,105 @@ def test_insert_error_handling(azure_ai_search_instance): payloads = [{"user_id": "user1"}] ids = ["doc1"] - # Exception should be raised + # Insert should raise an exception with pytest.raises(Exception) as exc_info: instance.insert(vectors, payloads, ids) - assert "Insert failed" in str(exc_info.value) - - -def test_search_with_complex_filters(azure_ai_search_instance): - """Test searching with multiple filter conditions as a user might need.""" - instance, mock_search_client, _ = azure_ai_search_instance + assert "Insert failed for document doc1" in str(exc_info.value) - # Configure mock response - mock_search_client.search.return_value = [ - { - "id": "doc1", - "@search.score": 0.95, - "payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"}) - } + # Configure mock to return mixed success/failure for multiple documents + mock_search_client.upload_documents.return_value = [ + {"status": True, "id": "doc1"}, + {"status": False, "id": "doc2", "errorMessage": "Azure error"} ] - # Search with multiple filters (common in multi-tenant or segmented applications) - filters = { - "user_id": "user1", - "run_id": "run123", - "agent_id": "agent456" - } - results = instance.search([0.1, 0.2, 0.3], filters=filters) + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + payloads = [{"user_id": "user1"}, {"user_id": "user2"}] + ids = ["doc1", "doc2"] - # Verify search was called with the correct filter expression - mock_search_client.search.assert_called_once() - _, kwargs = mock_search_client.search.call_args - assert "filter" in kwargs - - # The filter should contain all three conditions - filter_expr = kwargs["filter"] - assert "user_id eq 'user1'" in filter_expr - assert "run_id eq 'run123'" in filter_expr - assert "agent_id eq 'agent456'" in filter_expr - assert " and " in filter_expr # Conditions should be joined by AND - - -def test_empty_search_results(azure_ai_search_instance): - """Test behavior when search returns no results (common edge case).""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to return empty results - mock_search_client.search.return_value = [] - - # Search with a non-matching query - results = instance.search([0.9, 0.9, 0.9], limit=5) - - # Verify result handling - assert len(results) == 0 - - -def test_get_nonexistent_document(azure_ai_search_instance): - """Test behavior when getting a document that doesn't exist (should handle gracefully).""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to raise ResourceNotFoundError - mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found") - - # Get a non-existent document - result = instance.get("nonexistent_id") - - # Should return None instead of raising exception - assert result is None - - -def test_azure_service_error(azure_ai_search_instance): - """Test handling of Azure service errors (important for robustness).""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to raise HttpResponseError - http_error = HttpResponseError("Azure service is unavailable") - mock_search_client.search.side_effect = http_error - - # Attempt to search - with pytest.raises(HttpResponseError): - instance.search([0.1, 0.2, 0.3]) - - # Verify search was attempted - mock_search_client.search.assert_called_once() - - -def test_realistic_workflow(azure_ai_search_instance): - """Test a realistic workflow: insert → search → update → search again.""" - instance, mock_search_client, _ = azure_ai_search_instance - - # 1. Insert a document - vector = [0.1, 0.2, 0.3] - payload = {"user_id": "user1", "content": "Initial content"} - doc_id = "workflow_doc" - - mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}] - instance.insert([vector], [payload], [doc_id]) - - # 2. Search for the document - mock_search_client.search.return_value = [ - { - "id": doc_id, - "@search.score": 0.95, - "payload": json.dumps(payload) - } - ] - results = instance.search(vector, filters={"user_id": "user1"}) - assert len(results) == 1 - assert results[0].id == doc_id - - # 3. Update the document - updated_payload = {"user_id": "user1", "content": "Updated content"} - mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}] - instance.update(doc_id, payload=updated_payload) - - # 4. Search again to get updated document - mock_search_client.search.return_value = [ - { - "id": doc_id, - "@search.score": 0.95, - "payload": json.dumps(updated_payload) - } - ] - results = instance.search(vector, filters={"user_id": "user1"}) - assert len(results) == 1 - assert results[0].id == doc_id - assert results[0].payload["content"] == "Updated content" - - -def test_sanitize_special_characters(azure_ai_search_instance): - """Test that special characters in filter values are properly sanitized.""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock response - mock_search_client.search.return_value = [ - { - "id": "doc1", - "@search.score": 0.95, - "payload": json.dumps({"user_id": "user's-data"}) - } - ] - - # Search with a filter that has special characters (common in real-world data) - filters = {"user_id": "user's-data"} - results = instance.search([0.1, 0.2, 0.3], filters=filters) - - # Verify search was called with properly escaped filter - mock_search_client.search.assert_called_once() - _, kwargs = mock_search_client.search.call_args - assert "filter" in kwargs - - # The filter should have properly escaped single quotes - filter_expr = kwargs["filter"] - assert "user_id eq 'user''s-data'" in filter_expr - - -def test_list_collections(azure_ai_search_instance): - """Test listing all collections/indexes (for management interfaces).""" - instance, _, mock_index_client = azure_ai_search_instance - - # List the collections - collections = instance.list_cols() - - # Verify the correct method was called - mock_index_client.list_index_names.assert_called_once() - - # Check the result - assert collections == ["test-index"] - - -def test_filter_with_numeric_values(azure_ai_search_instance): - """Test filtering with numeric values (common for faceted search).""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock response - mock_search_client.search.return_value = [ - { - "id": "doc1", - "@search.score": 0.95, - "payload": json.dumps({"user_id": "user1", "count": 42}) - } - ] - - # Search with a numeric filter - # Note: In the actual implementation, numeric fields might need to be in the payload - filters = {"count": 42} - results = instance.search([0.1, 0.2, 0.3], filters=filters) - - # Verify the filter expression - mock_search_client.search.assert_called_once() - _, kwargs = mock_search_client.search.call_args - filter_expr = kwargs["filter"] - assert "count eq 42" in filter_expr # No quotes for numbers - - -def test_error_on_update_nonexistent(azure_ai_search_instance): - """Test behavior when updating a document that doesn't exist.""" - instance, mock_search_client, _ = azure_ai_search_instance - - # Configure mock to return a failure for the update - mock_search_client.merge_or_upload_documents.return_value = [ - {"status": False, "id": "nonexistent", "errorMessage": "Document not found"} - ] - - # Attempt to update a non-existent document + # Insert should raise an exception with pytest.raises(Exception) as exc_info: - instance.update("nonexistent", payload={"new": "data"}) + instance.insert(vectors, payloads, ids) - assert "Update failed" in str(exc_info.value) + assert "Insert failed for document doc2" in str(exc_info.value) -def test_different_compression_types(): - """Test creating instances with different compression types (important for performance tuning).""" - with patch("mem0.vector_stores.azure_ai_search.SearchClient"), \ - patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"): - - # Test with scalar compression - scalar_instance = AzureAISearch( - service_name="test-service", - collection_name="scalar-index", - api_key="test-api-key", - embedding_model_dims=3, - compression_type="scalar", - use_float16=False - ) - - # Test with no compression - no_compression_instance = AzureAISearch( - service_name="test-service", - collection_name="no-compression-index", - api_key="test-api-key", - embedding_model_dims=3, - compression_type=None, - use_float16=False - ) - - # No assertions needed - we're just verifying that initialization doesn't fail +def test_insert_with_missing_payload_fields(azure_ai_search_instance): + """Test inserting with payloads missing some of the expected fields.""" + instance, mock_search_client, _ = azure_ai_search_instance + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"content": "Some content without user_id, run_id, or agent_id"}] + ids = ["doc1"] + + instance.insert(vectors, payloads, ids) + + # Verify upload_documents was called correctly + mock_search_client.upload_documents.assert_called_once() + args, _ = mock_search_client.upload_documents.call_args + documents = args[0] + + # Verify document has payload but not the extra fields + assert len(documents) == 1 + assert documents[0]["id"] == "doc1" + assert documents[0]["vector"] == [0.1, 0.2, 0.3] + assert documents[0]["payload"] == json.dumps(payloads[0]) + assert "user_id" not in documents[0] + assert "run_id" not in documents[0] + assert "agent_id" not in documents[0] -def test_high_dimensional_vectors(): - """Test handling of high-dimensional vectors typical in AI embeddings.""" - with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ - patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"): - - # Configure the mock client - mock_search_client = MockSearchClient.return_value - mock_search_client.upload_documents = Mock() - mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}] - - # Create an instance with higher dimensions like those from embedding models - high_dim_instance = AzureAISearch( - service_name="test-service", - collection_name="high-dim-index", - api_key="test-api-key", - embedding_model_dims=1536, # Common for models like OpenAI's embeddings - compression_type="binary", # Compression often used with high-dim vectors - use_float16=True # Reduced precision often used for memory efficiency - ) - - # Create a high-dimensional vector (stub with zeros for testing) - high_dim_vector = [0.0] * 1536 - payload = {"user_id": "user1"} - doc_id = "high_dim_doc" - - # Insert the document - high_dim_instance.insert([high_dim_vector], [payload], [doc_id]) - - # Verify the insert was called with the full vector - mock_search_client.upload_documents.assert_called_once() - args, _ = mock_search_client.upload_documents.call_args - documents = args[0] - assert len(documents[0]["vector"]) == 1536 \ No newline at end of file +def test_insert_with_http_error(azure_ai_search_instance): + """Test insert when Azure client throws an HTTP error.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to raise an HttpResponseError + mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error") + + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"user_id": "user1"}] + ids = ["doc1"] + + # Insert should propagate the HTTP error + with pytest.raises(HttpResponseError) as exc_info: + instance.insert(vectors, payloads, ids) + + assert "Azure service error" in str(exc_info.value) + + +# --- Tests for search method --- + +def test_search_basic(azure_ai_search_instance): + """Test basic vector search without filters.""" + instance, mock_search_client, _ = azure_ai_search_instance + + # Configure mock to return search results + mock_search_client.search.return_value = [ + { + "id": "doc1", + "@search.score": 0.95, + "payload": json.dumps({"content": "Test content"}) + } + ] + + # Search with a vector + query_vector = [0.1, 0.2, 0.3] + results = instance.search(query_vector, limit=5) + + # Verify search was called correctly + mock_search_client.search.assert_called_once() + _, kwargs = mock_search_client.search.call_args + + # Check parameters + assert len(kwargs["vector_queries"]) == 1 + assert kwargs["vector_queries"][0].vector == query_vector + assert kwargs["vector_queries"][0].k_nearest_neighbors == 5 + assert kwargs["vector_queries"][0].fields == "vector" + assert kwargs["filter"] is None # No filters + assert kwargs["top"] == 5 + assert kwargs["vector_filter_mode"] == "preFilter" # Default mode + + # Check results + assert len(results) == 1 + assert results[0].id == "doc1" + assert results[0].score == 0.95 + assert results[0].payload == {"content": "Test content"} \ No newline at end of file