bugfix: update Azure AI Search Config (#2380)
This commit is contained in:
@@ -1,20 +1,28 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, patch, MagicMock, call
|
||||
import pytest
|
||||
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
|
||||
|
||||
# Import the AzureAISearch class and OutputData model from your module.
|
||||
from mem0.vector_stores.azure_ai_search import AzureAISearch
|
||||
# Import the AzureAISearch class and related models
|
||||
from mem0.vector_stores.azure_ai_search import AzureAISearch, OutputData
|
||||
from mem0.configs.vector_stores.azure_ai_search import AzureAISearchConfig
|
||||
|
||||
|
||||
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
|
||||
@pytest.fixture
|
||||
def mock_clients():
|
||||
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient:
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \
|
||||
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential:
|
||||
# Create mocked instances for search and index clients.
|
||||
mock_search_client = MockSearchClient.return_value
|
||||
mock_index_client = MockIndexClient.return_value
|
||||
|
||||
# Mock the client._client._config.user_agent_policy.add_user_agent
|
||||
mock_search_client._client = MagicMock()
|
||||
mock_search_client._client._config.user_agent_policy.add_user_agent = Mock()
|
||||
mock_index_client._client = MagicMock()
|
||||
mock_index_client._client._config.user_agent_policy.add_user_agent = Mock()
|
||||
|
||||
# Stub required methods on search_client.
|
||||
mock_search_client.upload_documents = Mock()
|
||||
@@ -29,7 +37,7 @@ def mock_clients():
|
||||
|
||||
# Stub required methods on index_client.
|
||||
mock_index_client.create_or_update_index = Mock()
|
||||
mock_index_client.list_indexes = Mock(return_value=[])
|
||||
mock_index_client.list_indexes = Mock()
|
||||
mock_index_client.list_index_names = Mock(return_value=["test-index"])
|
||||
mock_index_client.delete_index = Mock()
|
||||
# For col_info() we assume get_index returns an object with name and fields attributes.
|
||||
@@ -39,11 +47,12 @@ def mock_clients():
|
||||
mock_index_client.get_index = Mock(return_value=fake_index)
|
||||
mock_index_client.close = Mock()
|
||||
|
||||
yield mock_search_client, mock_index_client
|
||||
yield mock_search_client, mock_index_client, MockAzureKeyCredential
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_ai_search_instance(mock_clients):
|
||||
mock_search_client, mock_index_client = mock_clients
|
||||
mock_search_client, mock_index_client, _ = mock_clients
|
||||
# Create an instance with dummy parameters.
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
@@ -56,150 +65,334 @@ def azure_ai_search_instance(mock_clients):
|
||||
# Return instance and clients for verification.
|
||||
return instance, mock_search_client, mock_index_client
|
||||
|
||||
# --- Original tests ---
|
||||
|
||||
# --- Tests for AzureAISearchConfig ---
|
||||
|
||||
def test_config_validation_valid():
|
||||
"""Test valid configurations are accepted."""
|
||||
# Test minimal configuration
|
||||
config = AzureAISearchConfig(
|
||||
service_name="test-service",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768
|
||||
)
|
||||
assert config.collection_name == "mem0" # Default value
|
||||
assert config.service_name == "test-service"
|
||||
assert config.api_key == "test-api-key"
|
||||
assert config.embedding_model_dims == 768
|
||||
assert config.compression_type is None
|
||||
assert config.use_float16 is False
|
||||
|
||||
# Test with all optional parameters
|
||||
config = AzureAISearchConfig(
|
||||
collection_name="custom-index",
|
||||
service_name="test-service",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=1536,
|
||||
compression_type="scalar",
|
||||
use_float16=True
|
||||
)
|
||||
assert config.collection_name == "custom-index"
|
||||
assert config.compression_type == "scalar"
|
||||
assert config.use_float16 is True
|
||||
|
||||
|
||||
def test_config_validation_invalid_compression_type():
|
||||
"""Test that invalid compression types are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AzureAISearchConfig(
|
||||
service_name="test-service",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
compression_type="invalid-type" # Not a valid option
|
||||
)
|
||||
assert "Invalid compression_type" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_config_validation_deprecated_use_compression():
|
||||
"""Test that using the deprecated use_compression parameter raises an error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AzureAISearchConfig(
|
||||
service_name="test-service",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
use_compression=True # Deprecated parameter
|
||||
)
|
||||
# Fix: Use a partial string match instead of exact match
|
||||
assert "use_compression" in str(exc_info.value)
|
||||
assert "no longer supported" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_config_validation_extra_fields():
|
||||
"""Test that extra fields are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AzureAISearchConfig(
|
||||
service_name="test-service",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
unknown_parameter="value" # Extra field
|
||||
)
|
||||
assert "Extra fields not allowed" in str(exc_info.value)
|
||||
assert "unknown_parameter" in str(exc_info.value)
|
||||
|
||||
|
||||
# --- Tests for AzureAISearch initialization ---
|
||||
|
||||
def test_initialization(mock_clients):
|
||||
"""Test AzureAISearch initialization with different parameters."""
|
||||
mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients
|
||||
|
||||
# Test with minimal parameters
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="test-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768
|
||||
)
|
||||
|
||||
# Verify initialization parameters
|
||||
assert instance.index_name == "test-index"
|
||||
assert instance.collection_name == "test-index"
|
||||
assert instance.embedding_model_dims == 768
|
||||
assert instance.compression_type == "none" # Default when None is passed
|
||||
assert instance.use_float16 is False
|
||||
|
||||
# Verify client creation
|
||||
mock_azure_key_credential.assert_called_with("test-api-key")
|
||||
assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0]
|
||||
assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0]
|
||||
|
||||
# Verify index creation was called
|
||||
mock_index_client.create_or_update_index.assert_called_once()
|
||||
|
||||
|
||||
def test_initialization_with_compression_types(mock_clients):
|
||||
"""Test initialization with different compression types."""
|
||||
mock_search_client, mock_index_client, _ = mock_clients
|
||||
|
||||
# Test with scalar compression
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="scalar-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
compression_type="scalar"
|
||||
)
|
||||
assert instance.compression_type == "scalar"
|
||||
|
||||
# Capture the index creation call
|
||||
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
|
||||
index = args[0]
|
||||
# Verify scalar compression was configured
|
||||
assert hasattr(index.vector_search, 'compressions')
|
||||
assert len(index.vector_search.compressions) > 0
|
||||
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
|
||||
|
||||
# Test with binary compression
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="binary-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
compression_type="binary"
|
||||
)
|
||||
assert instance.compression_type == "binary"
|
||||
|
||||
# Capture the index creation call
|
||||
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
|
||||
index = args[0]
|
||||
# Verify binary compression was configured
|
||||
assert hasattr(index.vector_search, 'compressions')
|
||||
assert len(index.vector_search.compressions) > 0
|
||||
assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))
|
||||
|
||||
# Test with no compression
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="no-compression-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
compression_type=None
|
||||
)
|
||||
assert instance.compression_type == "none"
|
||||
|
||||
# Capture the index creation call
|
||||
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
|
||||
index = args[0]
|
||||
# Verify no compression was configured
|
||||
assert hasattr(index.vector_search, 'compressions')
|
||||
assert len(index.vector_search.compressions) == 0
|
||||
|
||||
|
||||
def test_initialization_with_float_precision(mock_clients):
|
||||
"""Test initialization with different float precision settings."""
|
||||
mock_search_client, mock_index_client, _ = mock_clients
|
||||
|
||||
# Test with half precision (float16)
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="float16-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
use_float16=True
|
||||
)
|
||||
assert instance.use_float16 is True
|
||||
|
||||
# Capture the index creation call
|
||||
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
|
||||
index = args[0]
|
||||
# Find the vector field and check its type
|
||||
vector_field = next((f for f in index.fields if f.name == "vector"), None)
|
||||
assert vector_field is not None
|
||||
assert "Edm.Half" in vector_field.type
|
||||
|
||||
# Test with full precision (float32)
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="float32-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
use_float16=False
|
||||
)
|
||||
assert instance.use_float16 is False
|
||||
|
||||
# Capture the index creation call
|
||||
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
|
||||
index = args[0]
|
||||
# Find the vector field and check its type
|
||||
vector_field = next((f for f in index.fields if f.name == "vector"), None)
|
||||
assert vector_field is not None
|
||||
assert "Edm.Single" in vector_field.type
|
||||
|
||||
|
||||
# --- Tests for create_col method ---
|
||||
|
||||
def test_create_col(azure_ai_search_instance):
|
||||
instance, mock_search_client, mock_index_client = azure_ai_search_instance
|
||||
# Upon initialization, create_col should be called.
|
||||
"""Test the create_col method creates an index with the correct configuration."""
|
||||
instance, _, mock_index_client = azure_ai_search_instance
|
||||
|
||||
# create_col is called during initialization, so we check the call that was already made
|
||||
mock_index_client.create_or_update_index.assert_called_once()
|
||||
# Optionally, you could inspect the call arguments for vector type.
|
||||
|
||||
# Verify the index configuration
|
||||
args, _ = mock_index_client.create_or_update_index.call_args
|
||||
index = args[0]
|
||||
|
||||
# Check basic properties
|
||||
assert index.name == "test-index"
|
||||
assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload
|
||||
|
||||
# Check that required fields are present
|
||||
field_names = [f.name for f in index.fields]
|
||||
assert "id" in field_names
|
||||
assert "vector" in field_names
|
||||
assert "payload" in field_names
|
||||
assert "user_id" in field_names
|
||||
assert "run_id" in field_names
|
||||
assert "agent_id" in field_names
|
||||
|
||||
# Check that id is the key field
|
||||
id_field = next(f for f in index.fields if f.name == "id")
|
||||
assert id_field.key is True
|
||||
|
||||
# Check vector search configuration
|
||||
assert index.vector_search is not None
|
||||
assert len(index.vector_search.profiles) == 1
|
||||
assert index.vector_search.profiles[0].name == "my-vector-config"
|
||||
assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config"
|
||||
|
||||
# Check algorithms
|
||||
assert len(index.vector_search.algorithms) == 1
|
||||
assert index.vector_search.algorithms[0].name == "my-algorithms-config"
|
||||
assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0]))
|
||||
|
||||
# With binary compression and float16, we should have compression configuration
|
||||
assert len(index.vector_search.compressions) == 1
|
||||
assert index.vector_search.compressions[0].compression_name == "myCompression"
|
||||
assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))
|
||||
|
||||
def test_insert(azure_ai_search_instance):
|
||||
|
||||
def test_create_col_scalar_compression(mock_clients):
|
||||
"""Test creating a collection with scalar compression."""
|
||||
mock_search_client, mock_index_client, _ = mock_clients
|
||||
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="scalar-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
compression_type="scalar"
|
||||
)
|
||||
|
||||
# Verify the index configuration
|
||||
args, _ = mock_index_client.create_or_update_index.call_args
|
||||
index = args[0]
|
||||
|
||||
# Check compression configuration
|
||||
assert len(index.vector_search.compressions) == 1
|
||||
assert index.vector_search.compressions[0].compression_name == "myCompression"
|
||||
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
|
||||
|
||||
# Check profile references compression
|
||||
assert index.vector_search.profiles[0].compression_name == "myCompression"
|
||||
|
||||
|
||||
def test_create_col_no_compression(mock_clients):
|
||||
"""Test creating a collection with no compression."""
|
||||
mock_search_client, mock_index_client, _ = mock_clients
|
||||
|
||||
instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="no-compression-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=768,
|
||||
compression_type=None
|
||||
)
|
||||
|
||||
# Verify the index configuration
|
||||
args, _ = mock_index_client.create_or_update_index.call_args
|
||||
index = args[0]
|
||||
|
||||
# Check compression configuration - should be empty
|
||||
assert len(index.vector_search.compressions) == 0
|
||||
|
||||
# Check profile doesn't reference compression
|
||||
assert index.vector_search.profiles[0].compression_name is None
|
||||
|
||||
|
||||
# --- Tests for insert method ---
|
||||
|
||||
def test_insert_single(azure_ai_search_instance):
|
||||
"""Test inserting a single vector."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vectors = [[0.1, 0.2, 0.3]]
|
||||
payloads = [{"user_id": "user1", "run_id": "run1"}]
|
||||
payloads = [{"user_id": "user1", "run_id": "run1", "agent_id": "agent1"}]
|
||||
ids = ["doc1"]
|
||||
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
# Verify upload_documents was called correctly
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
# Update expected_doc to include extra fields from payload.
|
||||
expected_doc = {
|
||||
"id": "doc1",
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"payload": json.dumps({"user_id": "user1", "run_id": "run1"}),
|
||||
"user_id": "user1",
|
||||
"run_id": "run1"
|
||||
}
|
||||
assert documents[0] == expected_doc
|
||||
|
||||
# Verify document structure
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["id"] == "doc1"
|
||||
assert documents[0]["vector"] == [0.1, 0.2, 0.3]
|
||||
assert documents[0]["payload"] == json.dumps(payloads[0])
|
||||
assert documents[0]["user_id"] == "user1"
|
||||
assert documents[0]["run_id"] == "run1"
|
||||
assert documents[0]["agent_id"] == "agent1"
|
||||
|
||||
def test_search_preFilter(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
# Setup a fake search result returned by the mocked search method.
|
||||
fake_result = {
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user1"})
|
||||
}
|
||||
# Configure the mock to return an iterator (list) with fake_result.
|
||||
mock_search_client.search.return_value = [fake_result]
|
||||
|
||||
query_vector = [0.1, 0.2, 0.3]
|
||||
results = instance.search(query_vector, limit=1, filters={"user_id": "user1"}, vector_filter_mode="preFilter")
|
||||
|
||||
# Verify that the search method was called with vector_filter_mode="preFilter".
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, called_kwargs = mock_search_client.search.call_args
|
||||
assert called_kwargs.get("vector_filter_mode") == "preFilter"
|
||||
|
||||
# Verify that the output is parsed correctly.
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].payload == {"user_id": "user1"}
|
||||
|
||||
def test_search_postFilter(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
# Setup a fake search result for postFilter.
|
||||
fake_result = {
|
||||
"id": "doc2",
|
||||
"@search.score": 0.85,
|
||||
"payload": json.dumps({"user_id": "user2"})
|
||||
}
|
||||
mock_search_client.search.return_value = [fake_result]
|
||||
|
||||
query_vector = [0.4, 0.5, 0.6]
|
||||
results = instance.search(query_vector, limit=1, filters={"user_id": "user2"}, vector_filter_mode="postFilter")
|
||||
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, called_kwargs = mock_search_client.search.call_args
|
||||
assert called_kwargs.get("vector_filter_mode") == "postFilter"
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc2"
|
||||
assert results[0].score == 0.85
|
||||
assert results[0].payload == {"user_id": "user2"}
|
||||
|
||||
def test_delete(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vector_id = "doc1"
|
||||
# Set delete_documents to return an iterable with a successful response.
|
||||
mock_search_client.delete_documents.return_value = [{"status": True, "id": vector_id}]
|
||||
instance.delete(vector_id)
|
||||
mock_search_client.delete_documents.assert_called_once_with(documents=[{"id": vector_id}])
|
||||
|
||||
def test_update(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vector_id = "doc1"
|
||||
new_vector = [0.7, 0.8, 0.9]
|
||||
new_payload = {"user_id": "updated"}
|
||||
# Set merge_or_upload_documents to return an iterable with a successful response.
|
||||
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": vector_id}]
|
||||
instance.update(vector_id, vector=new_vector, payload=new_payload)
|
||||
mock_search_client.merge_or_upload_documents.assert_called_once()
|
||||
kwargs = mock_search_client.merge_or_upload_documents.call_args.kwargs
|
||||
document = kwargs["documents"][0]
|
||||
assert document["id"] == vector_id
|
||||
assert document["vector"] == new_vector
|
||||
assert document["payload"] == json.dumps(new_payload)
|
||||
# The update method will also add the 'user_id' field.
|
||||
assert document["user_id"] == "updated"
|
||||
|
||||
def test_get(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
fake_result = {
|
||||
"id": "doc1",
|
||||
"payload": json.dumps({"user_id": "user1"})
|
||||
}
|
||||
mock_search_client.get_document.return_value = fake_result
|
||||
result = instance.get("doc1")
|
||||
mock_search_client.get_document.assert_called_once_with(key="doc1")
|
||||
assert result.id == "doc1"
|
||||
assert result.payload == {"user_id": "user1"}
|
||||
assert result.score is None
|
||||
|
||||
def test_list(azure_ai_search_instance):
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
fake_result = {
|
||||
"id": "doc1",
|
||||
"@search.score": 0.99,
|
||||
"payload": json.dumps({"user_id": "user1"})
|
||||
}
|
||||
mock_search_client.search.return_value = [fake_result]
|
||||
# Call list with a simple filter.
|
||||
results = instance.list(filters={"user_id": "user1"}, limit=1)
|
||||
# Verify the search method was called with the proper parameters.
|
||||
expected_filter = instance._build_filter_expression({"user_id": "user1"})
|
||||
mock_search_client.search.assert_called_once_with(
|
||||
search_text="*",
|
||||
filter=expected_filter,
|
||||
top=1
|
||||
)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc1"
|
||||
|
||||
# --- New tests for practical end-user scenarios ---
|
||||
|
||||
def test_bulk_insert(azure_ai_search_instance):
|
||||
"""Test inserting a batch of documents (common for initial data loading)."""
|
||||
def test_insert_multiple(azure_ai_search_instance):
|
||||
"""Test inserting multiple vectors in one call."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Create a batch of 10 documents
|
||||
num_docs = 10
|
||||
vectors = [[0.1, 0.2, 0.3] for _ in range(num_docs)]
|
||||
# Create multiple vectors
|
||||
num_docs = 3
|
||||
vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)]
|
||||
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
|
||||
ids = [f"doc{i}" for i in range(num_docs)]
|
||||
|
||||
@@ -208,25 +401,35 @@ def test_bulk_insert(azure_ai_search_instance):
|
||||
{"status": True, "id": id_val} for id_val in ids
|
||||
]
|
||||
|
||||
# Insert the batch
|
||||
# Insert the documents
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
# Verify the call
|
||||
# Verify upload_documents was called with correct documents
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
|
||||
# Verify all documents were included
|
||||
assert len(documents) == num_docs
|
||||
|
||||
# Verify the first and last document
|
||||
# Check first document
|
||||
assert documents[0]["id"] == "doc0"
|
||||
assert documents[-1]["id"] == f"doc{num_docs-1}"
|
||||
assert documents[0]["vector"] == [0.0, 0.1, 0.2]
|
||||
assert documents[0]["payload"] == json.dumps(payloads[0])
|
||||
assert documents[0]["user_id"] == "user0"
|
||||
|
||||
# Check last document
|
||||
assert documents[2]["id"] == "doc2"
|
||||
assert documents[2]["vector"] == [0.2, 0.3, 0.4]
|
||||
assert documents[2]["payload"] == json.dumps(payloads[2])
|
||||
assert documents[2]["user_id"] == "user2"
|
||||
|
||||
|
||||
def test_insert_error_handling(azure_ai_search_instance):
|
||||
"""Test how the class handles Azure errors during insertion."""
|
||||
def test_insert_with_error(azure_ai_search_instance):
|
||||
"""Test insert when Azure returns an error for one or more documents."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return a failure for one document
|
||||
# Configure mock to return an error for one document
|
||||
mock_search_client.upload_documents.return_value = [
|
||||
{"status": False, "id": "doc1", "errorMessage": "Azure error"}
|
||||
]
|
||||
@@ -235,274 +438,105 @@ def test_insert_error_handling(azure_ai_search_instance):
|
||||
payloads = [{"user_id": "user1"}]
|
||||
ids = ["doc1"]
|
||||
|
||||
# Exception should be raised
|
||||
# Insert should raise an exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
assert "Insert failed" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_search_with_complex_filters(azure_ai_search_instance):
|
||||
"""Test searching with multiple filter conditions as a user might need."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
assert "Insert failed for document doc1" in str(exc_info.value)
|
||||
|
||||
# Configure mock response
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user1", "run_id": "run123", "agent_id": "agent456"})
|
||||
}
|
||||
# Configure mock to return mixed success/failure for multiple documents
|
||||
mock_search_client.upload_documents.return_value = [
|
||||
{"status": True, "id": "doc1"},
|
||||
{"status": False, "id": "doc2", "errorMessage": "Azure error"}
|
||||
]
|
||||
|
||||
# Search with multiple filters (common in multi-tenant or segmented applications)
|
||||
filters = {
|
||||
"user_id": "user1",
|
||||
"run_id": "run123",
|
||||
"agent_id": "agent456"
|
||||
}
|
||||
results = instance.search([0.1, 0.2, 0.3], filters=filters)
|
||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
payloads = [{"user_id": "user1"}, {"user_id": "user2"}]
|
||||
ids = ["doc1", "doc2"]
|
||||
|
||||
# Verify search was called with the correct filter expression
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
assert "filter" in kwargs
|
||||
|
||||
# The filter should contain all three conditions
|
||||
filter_expr = kwargs["filter"]
|
||||
assert "user_id eq 'user1'" in filter_expr
|
||||
assert "run_id eq 'run123'" in filter_expr
|
||||
assert "agent_id eq 'agent456'" in filter_expr
|
||||
assert " and " in filter_expr # Conditions should be joined by AND
|
||||
|
||||
|
||||
def test_empty_search_results(azure_ai_search_instance):
|
||||
"""Test behavior when search returns no results (common edge case)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return empty results
|
||||
mock_search_client.search.return_value = []
|
||||
|
||||
# Search with a non-matching query
|
||||
results = instance.search([0.9, 0.9, 0.9], limit=5)
|
||||
|
||||
# Verify result handling
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_get_nonexistent_document(azure_ai_search_instance):
|
||||
"""Test behavior when getting a document that doesn't exist (should handle gracefully)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to raise ResourceNotFoundError
|
||||
mock_search_client.get_document.side_effect = ResourceNotFoundError("Document not found")
|
||||
|
||||
# Get a non-existent document
|
||||
result = instance.get("nonexistent_id")
|
||||
|
||||
# Should return None instead of raising exception
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_azure_service_error(azure_ai_search_instance):
|
||||
"""Test handling of Azure service errors (important for robustness)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to raise HttpResponseError
|
||||
http_error = HttpResponseError("Azure service is unavailable")
|
||||
mock_search_client.search.side_effect = http_error
|
||||
|
||||
# Attempt to search
|
||||
with pytest.raises(HttpResponseError):
|
||||
instance.search([0.1, 0.2, 0.3])
|
||||
|
||||
# Verify search was attempted
|
||||
mock_search_client.search.assert_called_once()
|
||||
|
||||
|
||||
def test_realistic_workflow(azure_ai_search_instance):
|
||||
"""Test a realistic workflow: insert → search → update → search again."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# 1. Insert a document
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
payload = {"user_id": "user1", "content": "Initial content"}
|
||||
doc_id = "workflow_doc"
|
||||
|
||||
mock_search_client.upload_documents.return_value = [{"status": True, "id": doc_id}]
|
||||
instance.insert([vector], [payload], [doc_id])
|
||||
|
||||
# 2. Search for the document
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": doc_id,
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps(payload)
|
||||
}
|
||||
]
|
||||
results = instance.search(vector, filters={"user_id": "user1"})
|
||||
assert len(results) == 1
|
||||
assert results[0].id == doc_id
|
||||
|
||||
# 3. Update the document
|
||||
updated_payload = {"user_id": "user1", "content": "Updated content"}
|
||||
mock_search_client.merge_or_upload_documents.return_value = [{"status": True, "id": doc_id}]
|
||||
instance.update(doc_id, payload=updated_payload)
|
||||
|
||||
# 4. Search again to get updated document
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": doc_id,
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps(updated_payload)
|
||||
}
|
||||
]
|
||||
results = instance.search(vector, filters={"user_id": "user1"})
|
||||
assert len(results) == 1
|
||||
assert results[0].id == doc_id
|
||||
assert results[0].payload["content"] == "Updated content"
|
||||
|
||||
|
||||
def test_sanitize_special_characters(azure_ai_search_instance):
|
||||
"""Test that special characters in filter values are properly sanitized."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock response
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user's-data"})
|
||||
}
|
||||
]
|
||||
|
||||
# Search with a filter that has special characters (common in real-world data)
|
||||
filters = {"user_id": "user's-data"}
|
||||
results = instance.search([0.1, 0.2, 0.3], filters=filters)
|
||||
|
||||
# Verify search was called with properly escaped filter
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
assert "filter" in kwargs
|
||||
|
||||
# The filter should have properly escaped single quotes
|
||||
filter_expr = kwargs["filter"]
|
||||
assert "user_id eq 'user''s-data'" in filter_expr
|
||||
|
||||
|
||||
def test_list_collections(azure_ai_search_instance):
|
||||
"""Test listing all collections/indexes (for management interfaces)."""
|
||||
instance, _, mock_index_client = azure_ai_search_instance
|
||||
|
||||
# List the collections
|
||||
collections = instance.list_cols()
|
||||
|
||||
# Verify the correct method was called
|
||||
mock_index_client.list_index_names.assert_called_once()
|
||||
|
||||
# Check the result
|
||||
assert collections == ["test-index"]
|
||||
|
||||
|
||||
def test_filter_with_numeric_values(azure_ai_search_instance):
|
||||
"""Test filtering with numeric values (common for faceted search)."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock response
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"user_id": "user1", "count": 42})
|
||||
}
|
||||
]
|
||||
|
||||
# Search with a numeric filter
|
||||
# Note: In the actual implementation, numeric fields might need to be in the payload
|
||||
filters = {"count": 42}
|
||||
results = instance.search([0.1, 0.2, 0.3], filters=filters)
|
||||
|
||||
# Verify the filter expression
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
filter_expr = kwargs["filter"]
|
||||
assert "count eq 42" in filter_expr # No quotes for numbers
|
||||
|
||||
|
||||
def test_error_on_update_nonexistent(azure_ai_search_instance):
|
||||
"""Test behavior when updating a document that doesn't exist."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return a failure for the update
|
||||
mock_search_client.merge_or_upload_documents.return_value = [
|
||||
{"status": False, "id": "nonexistent", "errorMessage": "Document not found"}
|
||||
]
|
||||
|
||||
# Attempt to update a non-existent document
|
||||
# Insert should raise an exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
instance.update("nonexistent", payload={"new": "data"})
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
assert "Update failed" in str(exc_info.value)
|
||||
assert "Insert failed for document doc2" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_different_compression_types():
|
||||
"""Test creating instances with different compression types (important for performance tuning)."""
|
||||
with patch("mem0.vector_stores.azure_ai_search.SearchClient"), \
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
|
||||
|
||||
# Test with scalar compression
|
||||
scalar_instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="scalar-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=3,
|
||||
compression_type="scalar",
|
||||
use_float16=False
|
||||
)
|
||||
|
||||
# Test with no compression
|
||||
no_compression_instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="no-compression-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=3,
|
||||
compression_type=None,
|
||||
use_float16=False
|
||||
)
|
||||
|
||||
# No assertions needed - we're just verifying that initialization doesn't fail
|
||||
def test_insert_with_missing_payload_fields(azure_ai_search_instance):
|
||||
"""Test inserting with payloads missing some of the expected fields."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
vectors = [[0.1, 0.2, 0.3]]
|
||||
payloads = [{"content": "Some content without user_id, run_id, or agent_id"}]
|
||||
ids = ["doc1"]
|
||||
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
# Verify upload_documents was called correctly
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
|
||||
# Verify document has payload but not the extra fields
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["id"] == "doc1"
|
||||
assert documents[0]["vector"] == [0.1, 0.2, 0.3]
|
||||
assert documents[0]["payload"] == json.dumps(payloads[0])
|
||||
assert "user_id" not in documents[0]
|
||||
assert "run_id" not in documents[0]
|
||||
assert "agent_id" not in documents[0]
|
||||
|
||||
|
||||
def test_high_dimensional_vectors():
|
||||
"""Test handling of high-dimensional vectors typical in AI embeddings."""
|
||||
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
|
||||
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient"):
|
||||
|
||||
# Configure the mock client
|
||||
mock_search_client = MockSearchClient.return_value
|
||||
mock_search_client.upload_documents = Mock()
|
||||
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1"}]
|
||||
|
||||
# Create an instance with higher dimensions like those from embedding models
|
||||
high_dim_instance = AzureAISearch(
|
||||
service_name="test-service",
|
||||
collection_name="high-dim-index",
|
||||
api_key="test-api-key",
|
||||
embedding_model_dims=1536, # Common for models like OpenAI's embeddings
|
||||
compression_type="binary", # Compression often used with high-dim vectors
|
||||
use_float16=True # Reduced precision often used for memory efficiency
|
||||
)
|
||||
|
||||
# Create a high-dimensional vector (stub with zeros for testing)
|
||||
high_dim_vector = [0.0] * 1536
|
||||
payload = {"user_id": "user1"}
|
||||
doc_id = "high_dim_doc"
|
||||
|
||||
# Insert the document
|
||||
high_dim_instance.insert([high_dim_vector], [payload], [doc_id])
|
||||
|
||||
# Verify the insert was called with the full vector
|
||||
mock_search_client.upload_documents.assert_called_once()
|
||||
args, _ = mock_search_client.upload_documents.call_args
|
||||
documents = args[0]
|
||||
assert len(documents[0]["vector"]) == 1536
|
||||
def test_insert_with_http_error(azure_ai_search_instance):
|
||||
"""Test insert when Azure client throws an HTTP error."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to raise an HttpResponseError
|
||||
mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")
|
||||
|
||||
vectors = [[0.1, 0.2, 0.3]]
|
||||
payloads = [{"user_id": "user1"}]
|
||||
ids = ["doc1"]
|
||||
|
||||
# Insert should propagate the HTTP error
|
||||
with pytest.raises(HttpResponseError) as exc_info:
|
||||
instance.insert(vectors, payloads, ids)
|
||||
|
||||
assert "Azure service error" in str(exc_info.value)
|
||||
|
||||
|
||||
# --- Tests for search method ---
|
||||
|
||||
def test_search_basic(azure_ai_search_instance):
|
||||
"""Test basic vector search without filters."""
|
||||
instance, mock_search_client, _ = azure_ai_search_instance
|
||||
|
||||
# Configure mock to return search results
|
||||
mock_search_client.search.return_value = [
|
||||
{
|
||||
"id": "doc1",
|
||||
"@search.score": 0.95,
|
||||
"payload": json.dumps({"content": "Test content"})
|
||||
}
|
||||
]
|
||||
|
||||
# Search with a vector
|
||||
query_vector = [0.1, 0.2, 0.3]
|
||||
results = instance.search(query_vector, limit=5)
|
||||
|
||||
# Verify search was called correctly
|
||||
mock_search_client.search.assert_called_once()
|
||||
_, kwargs = mock_search_client.search.call_args
|
||||
|
||||
# Check parameters
|
||||
assert len(kwargs["vector_queries"]) == 1
|
||||
assert kwargs["vector_queries"][0].vector == query_vector
|
||||
assert kwargs["vector_queries"][0].k_nearest_neighbors == 5
|
||||
assert kwargs["vector_queries"][0].fields == "vector"
|
||||
assert kwargs["filter"] is None # No filters
|
||||
assert kwargs["top"] == 5
|
||||
assert kwargs["vector_filter_mode"] == "preFilter" # Default mode
|
||||
|
||||
# Check results
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "doc1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].payload == {"content": "Test content"}
|
||||
Reference in New Issue
Block a user