Formatting (#2750)
@@ -13,13 +13,15 @@ from mem0.vector_stores.azure_ai_search import AzureAISearch
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture
def mock_clients():
-    with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
-        patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \
-        patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential:
+    with (
+        patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient,
+        patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient,
+        patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential,
+    ):
        # Create mocked instances for search and index clients.
        mock_search_client = MockSearchClient.return_value
        mock_index_client = MockIndexClient.return_value

        # Mock the client._client._config.user_agent_policy.add_user_agent
        mock_search_client._client = MagicMock()
        mock_search_client._client._config.user_agent_policy.add_user_agent = Mock()
@@ -62,7 +64,7 @@ def azure_ai_search_instance(mock_clients):
        api_key="test-api-key",
        embedding_model_dims=3,
        compression_type="binary",  # testing binary quantization option
-        use_float16=True
+        use_float16=True,
    )
    # Return instance and clients for verification.
    return instance, mock_search_client, mock_index_client
@@ -70,21 +72,18 @@ def azure_ai_search_instance(mock_clients):

# --- Tests for AzureAISearchConfig ---


def test_config_validation_valid():
    """Test valid configurations are accepted."""
    # Test minimal configuration
-    config = AzureAISearchConfig(
-        service_name="test-service",
-        api_key="test-api-key",
-        embedding_model_dims=768
-    )
+    config = AzureAISearchConfig(service_name="test-service", api_key="test-api-key", embedding_model_dims=768)
    assert config.collection_name == "mem0"  # Default value
    assert config.service_name == "test-service"
    assert config.api_key == "test-api-key"
    assert config.embedding_model_dims == 768
    assert config.compression_type is None
    assert config.use_float16 is False

    # Test with all optional parameters
    config = AzureAISearchConfig(
        collection_name="custom-index",
@@ -92,7 +91,7 @@ def test_config_validation_valid():
        api_key="test-api-key",
        embedding_model_dims=1536,
        compression_type="scalar",
-        use_float16=True
+        use_float16=True,
    )
    assert config.collection_name == "custom-index"
    assert config.compression_type == "scalar"
@@ -106,7 +105,7 @@ def test_config_validation_invalid_compression_type():
            service_name="test-service",
            api_key="test-api-key",
            embedding_model_dims=768,
-            compression_type="invalid-type"  # Not a valid option
+            compression_type="invalid-type",  # Not a valid option
        )
    assert "Invalid compression_type" in str(exc_info.value)

@@ -118,7 +117,7 @@ def test_config_validation_deprecated_use_compression():
            service_name="test-service",
            api_key="test-api-key",
            embedding_model_dims=768,
-            use_compression=True  # Deprecated parameter
+            use_compression=True,  # Deprecated parameter
        )
    # Fix: Use a partial string match instead of exact match
    assert "use_compression" in str(exc_info.value)
@@ -132,7 +131,7 @@ def test_config_validation_extra_fields():
            service_name="test-service",
            api_key="test-api-key",
            embedding_model_dims=768,
-            unknown_parameter="value"  # Extra field
+            unknown_parameter="value",  # Extra field
        )
    assert "Extra fields not allowed" in str(exc_info.value)
    assert "unknown_parameter" in str(exc_info.value)
@@ -140,30 +139,28 @@ def test_config_validation_extra_fields():

# --- Tests for AzureAISearch initialization ---


def test_initialization(mock_clients):
    """Test AzureAISearch initialization with different parameters."""
    mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients

    # Test with minimal parameters
    instance = AzureAISearch(
-        service_name="test-service",
-        collection_name="test-index",
-        api_key="test-api-key",
-        embedding_model_dims=768
+        service_name="test-service", collection_name="test-index", api_key="test-api-key", embedding_model_dims=768
    )

    # Verify initialization parameters
    assert instance.index_name == "test-index"
    assert instance.collection_name == "test-index"
    assert instance.embedding_model_dims == 768
    assert instance.compression_type == "none"  # Default when None is passed
    assert instance.use_float16 is False

    # Verify client creation
    mock_azure_key_credential.assert_called_with("test-api-key")
    assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0]
    assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0]

    # Verify index creation was called
    mock_index_client.create_or_update_index.assert_called_once()

@@ -171,75 +168,75 @@ def test_initialization(mock_clients):
def test_initialization_with_compression_types(mock_clients):
    """Test initialization with different compression types."""
    mock_search_client, mock_index_client, _ = mock_clients

    # Test with scalar compression
    instance = AzureAISearch(
        service_name="test-service",
        collection_name="scalar-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        compression_type="scalar"
+        compression_type="scalar",
    )
    assert instance.compression_type == "scalar"

    # Capture the index creation call
    args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
    index = args[0]
    # Verify scalar compression was configured
-    assert hasattr(index.vector_search, 'compressions')
+    assert hasattr(index.vector_search, "compressions")
    assert len(index.vector_search.compressions) > 0
    assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))

    # Test with binary compression
    instance = AzureAISearch(
        service_name="test-service",
        collection_name="binary-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        compression_type="binary"
+        compression_type="binary",
    )
    assert instance.compression_type == "binary"

    # Capture the index creation call
    args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
    index = args[0]
    # Verify binary compression was configured
-    assert hasattr(index.vector_search, 'compressions')
+    assert hasattr(index.vector_search, "compressions")
    assert len(index.vector_search.compressions) > 0
    assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))

    # Test with no compression
    instance = AzureAISearch(
        service_name="test-service",
        collection_name="no-compression-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        compression_type=None
+        compression_type=None,
    )
    assert instance.compression_type == "none"

    # Capture the index creation call
    args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
    index = args[0]
    # Verify no compression was configured
-    assert hasattr(index.vector_search, 'compressions')
+    assert hasattr(index.vector_search, "compressions")
    assert len(index.vector_search.compressions) == 0


def test_initialization_with_float_precision(mock_clients):
    """Test initialization with different float precision settings."""
    mock_search_client, mock_index_client, _ = mock_clients

    # Test with half precision (float16)
    instance = AzureAISearch(
        service_name="test-service",
        collection_name="float16-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        use_float16=True
+        use_float16=True,
    )
    assert instance.use_float16 is True

    # Capture the index creation call
    args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
    index = args[0]
@@ -247,17 +244,17 @@ def test_initialization_with_float_precision(mock_clients):
    vector_field = next((f for f in index.fields if f.name == "vector"), None)
    assert vector_field is not None
    assert "Edm.Half" in vector_field.type

    # Test with full precision (float32)
    instance = AzureAISearch(
        service_name="test-service",
        collection_name="float32-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        use_float16=False
+        use_float16=False,
    )
    assert instance.use_float16 is False

    # Capture the index creation call
    args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
    index = args[0]
@@ -269,21 +266,22 @@ def test_initialization_with_float_precision(mock_clients):

# --- Tests for create_col method ---


def test_create_col(azure_ai_search_instance):
    """Test the create_col method creates an index with the correct configuration."""
    instance, _, mock_index_client = azure_ai_search_instance

    # create_col is called during initialization, so we check the call that was already made
    mock_index_client.create_or_update_index.assert_called_once()

    # Verify the index configuration
    args, _ = mock_index_client.create_or_update_index.call_args
    index = args[0]

    # Check basic properties
    assert index.name == "test-index"
    assert len(index.fields) == 6  # id, user_id, run_id, agent_id, vector, payload

    # Check that required fields are present
    field_names = [f.name for f in index.fields]
    assert "id" in field_names
@@ -292,22 +290,22 @@ def test_create_col(azure_ai_search_instance):
    assert "user_id" in field_names
    assert "run_id" in field_names
    assert "agent_id" in field_names

    # Check that id is the key field
    id_field = next(f for f in index.fields if f.name == "id")
    assert id_field.key is True

    # Check vector search configuration
    assert index.vector_search is not None
    assert len(index.vector_search.profiles) == 1
    assert index.vector_search.profiles[0].name == "my-vector-config"
    assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config"

    # Check algorithms
    assert len(index.vector_search.algorithms) == 1
    assert index.vector_search.algorithms[0].name == "my-algorithms-config"
    assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0]))

    # With binary compression and float16, we should have compression configuration
    assert len(index.vector_search.compressions) == 1
    assert index.vector_search.compressions[0].compression_name == "myCompression"
@@ -317,24 +315,24 @@ def test_create_col(azure_ai_search_instance):
def test_create_col_scalar_compression(mock_clients):
    """Test creating a collection with scalar compression."""
    mock_search_client, mock_index_client, _ = mock_clients

    AzureAISearch(
        service_name="test-service",
        collection_name="scalar-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        compression_type="scalar"
+        compression_type="scalar",
    )

    # Verify the index configuration
    args, _ = mock_index_client.create_or_update_index.call_args
    index = args[0]

    # Check compression configuration
    assert len(index.vector_search.compressions) == 1
    assert index.vector_search.compressions[0].compression_name == "myCompression"
    assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))

    # Check profile references compression
    assert index.vector_search.profiles[0].compression_name == "myCompression"

@@ -342,28 +340,29 @@ def test_create_col_scalar_compression(mock_clients):
def test_create_col_no_compression(mock_clients):
    """Test creating a collection with no compression."""
    mock_search_client, mock_index_client, _ = mock_clients

    AzureAISearch(
        service_name="test-service",
        collection_name="no-compression-index",
        api_key="test-api-key",
        embedding_model_dims=768,
-        compression_type=None
+        compression_type=None,
    )

    # Verify the index configuration
    args, _ = mock_index_client.create_or_update_index.call_args
    index = args[0]

    # Check compression configuration - should be empty
    assert len(index.vector_search.compressions) == 0

    # Check profile doesn't reference compression
    assert index.vector_search.profiles[0].compression_name is None


# --- Tests for insert method ---


def test_insert_single(azure_ai_search_instance):
    """Test inserting a single vector."""
    instance, mock_search_client, _ = azure_ai_search_instance
@@ -372,9 +371,7 @@ def test_insert_single(azure_ai_search_instance):
    ids = ["doc1"]

    # Fix: Include status_code: 201 in mock response
-    mock_search_client.upload_documents.return_value = [
-        {"status": True, "id": "doc1", "status_code": 201}
-    ]
+    mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1", "status_code": 201}]

    instance.insert(vectors, payloads, ids)

@@ -396,35 +393,35 @@ def test_insert_single(azure_ai_search_instance):
def test_insert_multiple(azure_ai_search_instance):
    """Test inserting multiple vectors in one call."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Create multiple vectors
    num_docs = 3
-    vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)]
+    vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)]
    payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
    ids = [f"doc{i}" for i in range(num_docs)]

    # Configure mock to return success for all documents (fix: add status_code 201)
    mock_search_client.upload_documents.return_value = [
        {"status": True, "id": id_val, "status_code": 201} for id_val in ids
    ]

    # Insert the documents
    instance.insert(vectors, payloads, ids)

    # Verify upload_documents was called with correct documents
    mock_search_client.upload_documents.assert_called_once()
    args, _ = mock_search_client.upload_documents.call_args
    documents = args[0]

    # Verify all documents were included
    assert len(documents) == num_docs

    # Check first document
    assert documents[0]["id"] == "doc0"
    assert documents[0]["vector"] == [0.0, 0.1, 0.2]
    assert documents[0]["payload"] == json.dumps(payloads[0])
    assert documents[0]["user_id"] == "user0"

    # Check last document
    assert documents[2]["id"] == "doc2"
    assert documents[2]["vector"] == [0.2, 0.3, 0.4]
@@ -437,9 +434,7 @@ def test_insert_with_error(azure_ai_search_instance):
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to return an error for one document
-    mock_search_client.upload_documents.return_value = [
-        {"status": False, "id": "doc1", "errorMessage": "Azure error"}
-    ]
+    mock_search_client.upload_documents.return_value = [{"status": False, "id": "doc1", "errorMessage": "Azure error"}]

    vectors = [[0.1, 0.2, 0.3]]
    payloads = [{"user_id": "user1"}]
@@ -454,7 +449,7 @@ def test_insert_with_error(azure_ai_search_instance):
    # Configure mock to return mixed success/failure for multiple documents
    mock_search_client.upload_documents.return_value = [
        {"status": True, "id": "doc1"},  # This should not cause failure
-        {"status": False, "id": "doc2", "errorMessage": "Azure error"}
+        {"status": False, "id": "doc2", "errorMessage": "Azure error"},
    ]

    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -465,8 +460,9 @@ def test_insert_with_error(azure_ai_search_instance):
    with pytest.raises(Exception) as exc_info:
        instance.insert(vectors, payloads, ids)

-    assert "Insert failed for document doc2" in str(exc_info.value) or \
-        "Insert failed for document doc1" in str(exc_info.value)
+    assert "Insert failed for document doc2" in str(exc_info.value) or "Insert failed for document doc1" in str(
+        exc_info.value
+    )


def test_insert_with_missing_payload_fields(azure_ai_search_instance):
@@ -500,23 +496,24 @@ def test_insert_with_missing_payload_fields(azure_ai_search_instance):
def test_insert_with_http_error(azure_ai_search_instance):
    """Test insert when Azure client throws an HTTP error."""
    instance, mock_search_client, _ = azure_ai_search_instance

    # Configure mock to raise an HttpResponseError
    mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")

    vectors = [[0.1, 0.2, 0.3]]
    payloads = [{"user_id": "user1"}]
    ids = ["doc1"]

    # Insert should propagate the HTTP error
    with pytest.raises(HttpResponseError) as exc_info:
        instance.insert(vectors, payloads, ids)

    assert "Azure service error" in str(exc_info.value)


# --- Tests for search method ---


def test_search_basic(azure_ai_search_instance):
    """Test basic vector search without filters."""
    instance, mock_search_client, _ = azure_ai_search_instance
@@ -536,9 +533,7 @@ def test_search_basic(azure_ai_search_instance):
    # Search with a vector
    query_text = "test query"  # Add a query string
    query_vector = [0.1, 0.2, 0.3]
-    results = instance.search(
-        query_text, query_vector, limit=5
-    )  # Pass the query string
+    results = instance.search(query_text, query_vector, limit=5)  # Pass the query string

    # Verify search was called correctly
    mock_search_client.search.assert_called_once()
@@ -7,9 +7,7 @@ import dotenv
try:
    from elasticsearch import Elasticsearch
except ImportError:
-    raise ImportError(
-        "Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`"
-    ) from None
+    raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None

from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData

@@ -19,20 +17,20 @@ class TestElasticsearchDB(unittest.TestCase):
    def setUpClass(cls):
        # Load environment variables before any test
        dotenv.load_dotenv()

        # Save original environment variables
        cls.original_env = {
-            'ES_URL': os.getenv('ES_URL', 'http://localhost:9200'),
-            'ES_USERNAME': os.getenv('ES_USERNAME', 'test_user'),
-            'ES_PASSWORD': os.getenv('ES_PASSWORD', 'test_password'),
-            'ES_CLOUD_ID': os.getenv('ES_CLOUD_ID', 'test_cloud_id')
+            "ES_URL": os.getenv("ES_URL", "http://localhost:9200"),
+            "ES_USERNAME": os.getenv("ES_USERNAME", "test_user"),
+            "ES_PASSWORD": os.getenv("ES_PASSWORD", "test_password"),
+            "ES_CLOUD_ID": os.getenv("ES_CLOUD_ID", "test_cloud_id"),
        }

        # Set test environment variables
-        os.environ['ES_URL'] = 'http://localhost'
-        os.environ['ES_USERNAME'] = 'test_user'
-        os.environ['ES_PASSWORD'] = 'test_password'
+        os.environ["ES_URL"] = "http://localhost"
+        os.environ["ES_USERNAME"] = "test_user"
+        os.environ["ES_PASSWORD"] = "test_password"

    def setUp(self):
        # Create a mock Elasticsearch client with proper attributes
        self.client_mock = MagicMock(spec=Elasticsearch)
@@ -41,25 +39,25 @@ class TestElasticsearchDB(unittest.TestCase):
        self.client_mock.indices.create = MagicMock()
        self.client_mock.indices.delete = MagicMock()
        self.client_mock.indices.get_alias = MagicMock()

        # Start patches BEFORE creating ElasticsearchDB instance
-        patcher = patch('mem0.vector_stores.elasticsearch.Elasticsearch', return_value=self.client_mock)
+        patcher = patch("mem0.vector_stores.elasticsearch.Elasticsearch", return_value=self.client_mock)
        self.mock_es = patcher.start()
        self.addCleanup(patcher.stop)

        # Initialize ElasticsearchDB with test config and auto_create_index=False
        self.es_db = ElasticsearchDB(
-            host=os.getenv('ES_URL'),
+            host=os.getenv("ES_URL"),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
-            user=os.getenv('ES_USERNAME'),
-            password=os.getenv('ES_PASSWORD'),
+            user=os.getenv("ES_USERNAME"),
+            password=os.getenv("ES_PASSWORD"),
            verify_certs=False,
            use_ssl=False,
-            auto_create_index=False  # Disable auto creation for tests
+            auto_create_index=False,  # Disable auto creation for tests
        )

        # Reset mock counts after initialization
        self.client_mock.reset_mock()

@@ -80,15 +78,15 @@ class TestElasticsearchDB(unittest.TestCase):
        # Test when index doesn't exist
        self.client_mock.indices.exists.return_value = False
        self.es_db.create_index()

        # Verify index creation was called with correct settings
        self.client_mock.indices.create.assert_called_once()
        create_args = self.client_mock.indices.create.call_args[1]

        # Verify basic index settings
        self.assertEqual(create_args["index"], "test_collection")
        self.assertIn("mappings", create_args["body"])

        # Verify field mappings
        mappings = create_args["body"]["mappings"]["properties"]
        self.assertEqual(mappings["text"]["type"], "text")
@@ -97,53 +95,53 @@ class TestElasticsearchDB(unittest.TestCase):
        self.assertEqual(mappings["vector"]["index"], True)
        self.assertEqual(mappings["vector"]["similarity"], "cosine")
        self.assertEqual(mappings["metadata"]["type"], "object")

        # Reset mocks for next test
        self.client_mock.reset_mock()

        # Test when index already exists
        self.client_mock.indices.exists.return_value = True
        self.es_db.create_index()

        # Verify create was not called when index exists
        self.client_mock.indices.create.assert_not_called()

    def test_auto_create_index(self):
        # Reset mock
        self.client_mock.reset_mock()

        # Test with auto_create_index=True
        ElasticsearchDB(
-            host=os.getenv('ES_URL'),
+            host=os.getenv("ES_URL"),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
-            user=os.getenv('ES_USERNAME'),
-            password=os.getenv('ES_PASSWORD'),
+            user=os.getenv("ES_USERNAME"),
+            password=os.getenv("ES_PASSWORD"),
            verify_certs=False,
            use_ssl=False,
-            auto_create_index=True
+            auto_create_index=True,
        )

        # Verify create_index was called during initialization
        self.client_mock.indices.exists.assert_called_once()

        # Reset mock
        self.client_mock.reset_mock()

        # Test with auto_create_index=False
        ElasticsearchDB(
-            host=os.getenv('ES_URL'),
+            host=os.getenv("ES_URL"),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
-            user=os.getenv('ES_USERNAME'),
-            password=os.getenv('ES_PASSWORD'),
+            user=os.getenv("ES_USERNAME"),
+            password=os.getenv("ES_PASSWORD"),
            verify_certs=False,
            use_ssl=False,
-            auto_create_index=False
+            auto_create_index=False,
        )

        # Verify create_index was not called during initialization
        self.client_mock.indices.exists.assert_not_called()

@@ -152,17 +150,17 @@ class TestElasticsearchDB(unittest.TestCase):
        vectors = [[0.1] * 1536, [0.2] * 1536]
        payloads = [{"key1": "value1"}, {"key2": "value2"}]
        ids = ["id1", "id2"]

        # Mock bulk operation
-        with patch('mem0.vector_stores.elasticsearch.bulk') as mock_bulk:
+        with patch("mem0.vector_stores.elasticsearch.bulk") as mock_bulk:
            mock_bulk.return_value = (2, [])  # Simulate successful bulk insert

            # Perform insert
            results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids)

            # Verify bulk was called
            mock_bulk.assert_called_once()

            # Verify bulk actions format
            actions = mock_bulk.call_args[0][1]
            self.assertEqual(len(actions), 2)
@@ -170,7 +168,7 @@ class TestElasticsearchDB(unittest.TestCase):
            self.assertEqual(actions[0]["_id"], "id1")
            self.assertEqual(actions[0]["_source"]["vector"], vectors[0])
            self.assertEqual(actions[0]["_source"]["metadata"], payloads[0])

            # Verify returned objects
            self.assertEqual(len(results), 2)
            self.assertIsInstance(results[0], OutputData)
@@ -182,14 +180,7 @@ class TestElasticsearchDB(unittest.TestCase):
        mock_response = {
            "hits": {
                "hits": [
-                    {
-                        "_id": "id1",
-                        "_score": 0.8,
-                        "_source": {
-                            "vector": [0.1] * 1536,
-                            "metadata": {"key1": "value1"}
-                        }
-                    }
+                    {"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}
                ]
            }
        }
@@ -206,7 +197,7 @@ class TestElasticsearchDB(unittest.TestCase):
        # Verify search parameters
        self.assertEqual(search_args["index"], "test_collection")
        body = search_args["body"]

        # Verify KNN query structure
        self.assertIn("knn", body)
        self.assertEqual(body["knn"]["field"], "vector")
@@ -235,29 +226,24 @@ class TestElasticsearchDB(unittest.TestCase):
        self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)

        # Verify custom search query was used
-        self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
+        self.client_mock.search.assert_called_once_with(
+            index=self.es_db.collection_name, body={"custom_key": "custom_value"}
+        )

    def test_get(self):
        # Mock get response with correct structure
        mock_response = {
            "_id": "id1",
-            "_source": {
-                "vector": [0.1] * 1536,
-                "metadata": {"key": "value"},
-                "text": "sample text"
-            }
+            "_source": {"vector": [0.1] * 1536, "metadata": {"key": "value"}, "text": "sample text"},
        }
        self.client_mock.get.return_value = mock_response

        # Perform get
        result = self.es_db.get(vector_id="id1")

        # Verify get call
-        self.client_mock.get.assert_called_once_with(
-            index="test_collection",
-            id="id1"
-        )
+        self.client_mock.get.assert_called_once_with(index="test_collection", id="id1")

        # Verify result
        self.assertIsNotNone(result)
        self.assertEqual(result.id, "id1")
@@ -267,7 +253,7 @@ class TestElasticsearchDB(unittest.TestCase):
    def test_get_not_found(self):
        # Mock get raising exception
        self.client_mock.get.side_effect = Exception("Not found")

        # Verify get returns None when document not found
        result = self.es_db.get(vector_id="nonexistent")
        self.assertIsNone(result)
@@ -277,33 +263,19 @@ class TestElasticsearchDB(unittest.TestCase):
        mock_response = {
            "hits": {
                "hits": [
-                    {
-                        "_id": "id1",
-                        "_source": {
-                            "vector": [0.1] * 1536,
-                            "metadata": {"key1": "value1"}
-                        },
-                        "_score": 1.0
-                    },
-                    {
-                        "_id": "id2",
-                        "_source": {
-                            "vector": [0.2] * 1536,
-                            "metadata": {"key2": "value2"}
-                        },
-                        "_score": 0.8
-                    }
+                    {"_id": "id1", "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}, "_score": 1.0},
+                    {"_id": "id2", "_source": {"vector": [0.2] * 1536, "metadata": {"key2": "value2"}}, "_score": 0.8},
                ]
            }
        }
        self.client_mock.search.return_value = mock_response

        # Perform list operation
        results = self.es_db.list(limit=10)

        # Verify search call
        self.client_mock.search.assert_called_once()

        # Verify results
        self.assertEqual(len(results), 1)  # Outer list
        self.assertEqual(len(results[0]), 2)  # Inner list
@@ -316,30 +288,24 @@ class TestElasticsearchDB(unittest.TestCase):
    def test_delete(self):
        # Perform delete
        self.es_db.delete(vector_id="id1")

        # Verify delete call
-        self.client_mock.delete.assert_called_once_with(
-            index="test_collection",
-            id="id1"
-        )
+        self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1")

    def test_list_cols(self):
        # Mock indices response
        mock_indices = {"index1": {}, "index2": {}}
        self.client_mock.indices.get_alias.return_value = mock_indices

        # Get collections
        result = self.es_db.list_cols()

        # Verify result
        self.assertEqual(result, ["index1", "index2"])

    def test_delete_col(self):
        # Delete collection
        self.es_db.delete_col()

        # Verify delete call
-        self.client_mock.indices.delete.assert_called_once_with(
-            index="test_collection"
-        )
+        self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
@@ -21,9 +21,9 @@ def mock_faiss_index():
def faiss_instance(mock_faiss_index):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Mock the faiss index creation
-        with patch('faiss.IndexFlatL2', return_value=mock_faiss_index):
+        with patch("faiss.IndexFlatL2", return_value=mock_faiss_index):
            # Mock the faiss.write_index function
-            with patch('faiss.write_index'):
+            with patch("faiss.write_index"):
                # Create a FAISS instance with a temporary directory
                faiss_store = FAISS(
                    collection_name="test_collection",
@@ -37,14 +37,14 @@ def faiss_instance(mock_faiss_index):

def test_create_col(faiss_instance, mock_faiss_index):
    # Test creating a collection with euclidean distance
-    with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
-        with patch('faiss.write_index'):
+    with patch("faiss.IndexFlatL2", return_value=mock_faiss_index) as mock_index_flat_l2:
+        with patch("faiss.write_index"):
            faiss_instance.create_col(name="new_collection")
            mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)

    # Test creating a collection with inner product distance
-    with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
-        with patch('faiss.write_index'):
+    with patch("faiss.IndexFlatIP", return_value=mock_faiss_index) as mock_index_flat_ip:
+        with patch("faiss.write_index"):
            faiss_instance.create_col(name="new_collection", distance="inner_product")
            mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)

@@ -54,21 +54,21 @@ def test_insert(faiss_instance, mock_faiss_index):
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
    ids = ["id1", "id2"]

    # Mock the numpy array conversion
-    with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
+    with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
        # Mock index.add
        mock_faiss_index.add.return_value = None

        # Call insert
        faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

        # Verify numpy.array was called
        mock_np_array.assert_called_once_with(vectors, dtype=np.float32)

        # Verify index.add was called
        mock_faiss_index.add.assert_called_once()

        # Verify docstore and index_to_id were updated
        assert faiss_instance.docstore["id1"] == {"name": "vector1"}
        assert faiss_instance.docstore["id2"] == {"name": "vector2"}
@@ -79,39 +79,36 @@ def test_insert(faiss_instance, mock_faiss_index):
def test_search(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]

    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # First, create the mock for the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)

    # Then patch numpy.array only for the query vector conversion
-    with patch('numpy.array') as mock_np_array:
+    with patch("numpy.array") as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)

        # Then patch _parse_output to return the expected results
        expected_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
-            OutputData(id="id2", score=0.8, payload={"name": "vector2"})
+            OutputData(id="id2", score=0.8, payload={"name": "vector2"}),
        ]

-        with patch.object(faiss_instance, '_parse_output', return_value=expected_results):
+        with patch.object(faiss_instance, "_parse_output", return_value=expected_results):
            # Call search
            results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)

            # Verify numpy.array was called (but we don't check exact call arguments since it's complex)
            assert mock_np_array.called

            # Verify index.search was called
            mock_faiss_index.search.assert_called_once()

            # Verify results
            assert len(results) == 2
            assert results[0].id == "id1"
@@ -125,47 +122,41 @@ def test_search(faiss_instance, mock_faiss_index):
def test_search_with_filters(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]

    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1", "category": "A"},
-        "id2": {"name": "vector2", "category": "B"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # First set up the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)

    # Patch numpy.array for query vector conversion
-    with patch('numpy.array') as mock_np_array:
+    with patch("numpy.array") as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)

        # Directly mock the _parse_output method to return our expected values
        # We're simulating that _parse_output filters to just the first result
        all_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}),
-            OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"})
+            OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}),
        ]

        # Replace the _apply_filters method to handle our test case
-        with patch.object(faiss_instance, '_parse_output', return_value=all_results):
-            with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"):
+        with patch.object(faiss_instance, "_parse_output", return_value=all_results):
+            with patch.object(faiss_instance, "_apply_filters", side_effect=lambda p, f: p.get("category") == "A"):
                # Call search with filters
                results = faiss_instance.search(
-                    query="test query",
-                    vectors=query_vector,
-                    limit=2,
-                    filters={"category": "A"}
+                    query="test query", vectors=query_vector, limit=2, filters={"category": "A"}
                )

                # Verify numpy.array was called
                assert mock_np_array.called

                # Verify index.search was called
                mock_faiss_index.search.assert_called_once()

                # Verify filtered results - since we've mocked everything,
                # we should get just the result we want
                assert len(results) == 1
@@ -176,15 +167,12 @@ def test_search_with_filters(faiss_instance, mock_faiss_index):

def test_delete(faiss_instance):
    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # Call delete
    faiss_instance.delete(vector_id="id1")

    # Verify the vector was removed from docstore and index_to_id
    assert "id1" not in faiss_instance.docstore
    assert 0 not in faiss_instance.index_to_id
@@ -194,23 +182,20 @@ def test_delete(faiss_instance):

def test_update(faiss_instance, mock_faiss_index):
    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # Test updating payload only
    faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"})
    assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"}

    # Test updating vector
    # This requires mocking the delete and insert methods
-    with patch.object(faiss_instance, 'delete') as mock_delete:
-        with patch.object(faiss_instance, 'insert') as mock_insert:
+    with patch.object(faiss_instance, "delete") as mock_delete:
+        with patch.object(faiss_instance, "insert") as mock_insert:
            new_vector = [0.7, 0.8, 0.9]
            faiss_instance.update(vector_id="id2", vector=new_vector)

            # Verify delete and insert were called
            # Match the actual call signature (positional arg instead of keyword)
            mock_delete.assert_called_once_with("id2")
@@ -219,17 +204,14 @@ def test_update(faiss_instance, mock_faiss_index):

def test_get(faiss_instance):
    # Setup the docstore
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}

    # Test getting an existing vector
    result = faiss_instance.get(vector_id="id1")
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
    assert result.score is None

    # Test getting a non-existent vector
    result = faiss_instance.get(vector_id="id3")
    assert result is None
@@ -240,18 +222,18 @@ def test_list(faiss_instance):
    faiss_instance.docstore = {
        "id1": {"name": "vector1", "category": "A"},
        "id2": {"name": "vector2", "category": "B"},
-        "id3": {"name": "vector3", "category": "A"}
+        "id3": {"name": "vector3", "category": "A"},
    }

    # Test listing all vectors
    results = faiss_instance.list()
    # Fix the expected result - the list method returns a list of lists
    assert len(results[0]) == 3

    # Test listing with a limit
    results = faiss_instance.list(limit=2)
    assert len(results[0]) == 2

    # Test listing with filters
    results = faiss_instance.list(filters={"category": "A"})
    assert len(results[0]) == 2
@@ -263,10 +245,10 @@ def test_col_info(faiss_instance, mock_faiss_index):
    # Mock index attributes
    mock_faiss_index.ntotal = 5
    mock_faiss_index.d = 128

    # Get collection info
    info = faiss_instance.col_info()

    # Verify the returned info
    assert info["name"] == "test_collection"
    assert info["count"] == 5
@@ -276,14 +258,14 @@ def test_col_info(faiss_instance, mock_faiss_index):

def test_delete_col(faiss_instance):
    # Mock the os.remove function
-    with patch('os.remove') as mock_remove:
-        with patch('os.path.exists', return_value=True):
+    with patch("os.remove") as mock_remove:
+        with patch("os.path.exists", return_value=True):
            # Call delete_col
            faiss_instance.delete_col()

            # Verify os.remove was called twice (for index and docstore files)
            assert mock_remove.call_count == 2

            # Verify the internal state was reset
            assert faiss_instance.index is None
            assert faiss_instance.docstore == {}
@@ -293,17 +275,17 @@ def test_delete_col(faiss_instance):
def test_normalize_L2(faiss_instance, mock_faiss_index):
    # Setup a FAISS instance with normalize_L2=True
    faiss_instance.normalize_L2 = True

    # Prepare test data
    vectors = [[0.1, 0.2, 0.3]]

    # Mock numpy array conversion
-    with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)):
+    with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)):
        # Mock faiss.normalize_L2
-        with patch('faiss.normalize_L2') as mock_normalize:
+        with patch("faiss.normalize_L2") as mock_normalize:
            # Call insert
            faiss_instance.insert(vectors=vectors, ids=["id1"])

            # Verify faiss.normalize_L2 was called
            mock_normalize.assert_called_once()
@@ -11,11 +11,13 @@ def mock_langchain_client():
    with patch("langchain_community.vectorstores.VectorStore") as mock_client:
        yield mock_client


@pytest.fixture
def langchain_instance(mock_langchain_client):
    mock_client = Mock(spec=VectorStore)
    return Langchain(client=mock_client, collection_name="test_collection")


def test_insert_vectors(langchain_instance):
    # Test data
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -25,48 +27,31 @@ def test_insert_vectors(langchain_instance):
    # Test with add_embeddings method
    langchain_instance.client.add_embeddings = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
-    langchain_instance.client.add_embeddings.assert_called_once_with(
-        embeddings=vectors,
-        metadatas=payloads,
-        ids=ids
-    )
+    langchain_instance.client.add_embeddings.assert_called_once_with(embeddings=vectors, metadatas=payloads, ids=ids)

    # Test with add_texts method
    delattr(langchain_instance.client, "add_embeddings")  # Remove attribute completely
    langchain_instance.client.add_texts = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
-    langchain_instance.client.add_texts.assert_called_once_with(
-        texts=["text1", "text2"],
-        metadatas=payloads,
-        ids=ids
-    )
+    langchain_instance.client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=payloads, ids=ids)

    # Test with empty payloads
    langchain_instance.client.add_texts.reset_mock()
    langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
-    langchain_instance.client.add_texts.assert_called_once_with(
-        texts=["", ""],
-        metadatas=None,
-        ids=ids
-    )
+    langchain_instance.client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=ids)


def test_search_vectors(langchain_instance):
    # Mock search results
-    mock_docs = [
-        Mock(metadata={"name": "vector1"}, id="id1"),
-        Mock(metadata={"name": "vector2"}, id="id2")
-    ]
+    mock_docs = [Mock(metadata={"name": "vector1"}, id="id1"), Mock(metadata={"name": "vector2"}, id="id2")]
    langchain_instance.client.similarity_search_by_vector.return_value = mock_docs

    # Test search without filters
    vectors = [[0.1, 0.2, 0.3]]
    results = langchain_instance.search(query="", vectors=vectors, limit=2)

-    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
-        embedding=vectors,
-        k=2
-    )
+    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(embedding=vectors, k=2)

    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].payload == {"name": "vector1"}
@@ -76,11 +61,8 @@ def test_search_vectors(langchain_instance):
    # Test search with filters
    filters = {"name": "vector1"}
    langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
-    langchain_instance.client.similarity_search_by_vector.assert_called_with(
-        embedding=vectors,
-        k=2,
-        filter=filters
-    )
+    langchain_instance.client.similarity_search_by_vector.assert_called_with(embedding=vectors, k=2, filter=filters)


def test_get_vector(langchain_instance):
    # Mock get result
@@ -90,7 +72,7 @@ def test_get_vector(langchain_instance):
    # Test get existing vector
    result = langchain_instance.get("id1")
    langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])

    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
@@ -8,9 +8,7 @@ import pytest
try:
    from opensearchpy import AWSV4SignerAuth, OpenSearch
except ImportError:
-    raise ImportError(
-        "OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
-    ) from None
+    raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None

from mem0.vector_stores.opensearch import OpenSearchDB

@@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase):
    def setUpClass(cls):
        dotenv.load_dotenv()
        cls.original_env = {
-            'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'),
-            'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'),
-            'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password')
+            "OS_URL": os.getenv("OS_URL", "http://localhost:9200"),
+            "OS_USERNAME": os.getenv("OS_USERNAME", "test_user"),
+            "OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"),
        }
-        os.environ['OS_URL'] = 'http://localhost'
-        os.environ['OS_USERNAME'] = 'test_user'
-        os.environ['OS_PASSWORD'] = 'test_password'
+        os.environ["OS_URL"] = "http://localhost"
+        os.environ["OS_USERNAME"] = "test_user"
+        os.environ["OS_PASSWORD"] = "test_password"

    def setUp(self):
        self.client_mock = MagicMock(spec=OpenSearch)
@@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase):
        self.client_mock.delete = MagicMock()
        self.client_mock.search = MagicMock()

-        patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock)
+        patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock)
        self.mock_os = patcher.start()
        self.addCleanup(patcher.stop)

        self.os_db = OpenSearchDB(
-            host=os.getenv('OS_URL'),
+            host=os.getenv("OS_URL"),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
-            user=os.getenv('OS_USERNAME'),
-            password=os.getenv('OS_PASSWORD'),
+            user=os.getenv("OS_USERNAME"),
+            password=os.getenv("OS_PASSWORD"),
            verify_certs=False,
-            use_ssl=False
+            use_ssl=False,
        )
        self.client_mock.reset_mock()

@@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase):
        vectors = [[0.1] * 1536, [0.2] * 1536]
        payloads = [{"key1": "value1"}, {"key2": "value2"}]
        ids = ["id1", "id2"]

        # Mock the index method
        self.client_mock.index = MagicMock()

        results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)

        # Verify index was called twice (once for each vector)
        self.assertEqual(self.client_mock.index.call_count, 2)

        # Check first call
        first_call = self.client_mock.index.call_args_list[0]
        self.assertEqual(first_call[1]["index"], "test_collection")
        self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0])
        self.assertEqual(first_call[1]["body"]["payload"], payloads[0])
        self.assertEqual(first_call[1]["body"]["id"], ids[0])

        # Check second call
        second_call = self.client_mock.index.call_args_list[1]
        self.assertEqual(second_call[1]["index"], "test_collection")
        self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1])
        self.assertEqual(second_call[1]["body"]["payload"], payloads[1])
        self.assertEqual(second_call[1]["body"]["id"], ids[1])

        # Check results
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].id, "id1")
@@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase):
        self.client_mock.search.return_value = {"hits": {"hits": []}}
        result = self.os_db.get("nonexistent")
        self.assertIsNone(result)

    def test_update(self):
        vector = [0.3] * 1536
        payload = {"key3": "value3"}
@@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase):
        self.assertEqual(result, ["test_collection"])

    def test_search(self):
-        mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}}]}}
+        mock_response = {
+            "hits": {
+                "hits": [
+                    {
+                        "_id": "id1",
+                        "_score": 0.8,
+                        "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}},
+                    }
+                ]
+            }
+        }
        self.client_mock.search.return_value = mock_response
        vectors = [[0.1] * 1536]
        results = self.os_db.search(query="", vectors=vectors, limit=5)
@@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase):
        self.os_db.delete_col()
        self.client_mock.indices.delete.assert_called_once_with(index="test_collection")

    def test_init_with_http_auth(self):
        mock_credentials = MagicMock()
        mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")

-        with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
+        with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch:
            OpenSearchDB(
                host="localhost",
                port=9200,
@@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase):
                embedding_model_dims=1536,
                http_auth=mock_signer,
                verify_certs=True,
-                use_ssl=True
+                use_ssl=True,
            )

            # Verify OpenSearch was initialized with correct params
@@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase):
                use_ssl=True,
                verify_certs=True,
                connection_class=unittest.mock.ANY,
-                pool_maxsize=20
-            )
+                pool_maxsize=20,
+            )
@@ -12,6 +12,7 @@ def mock_pinecone_client():
|
||||
client.list_indexes.return_value.names.return_value = []
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pinecone_db(mock_pinecone_client):
|
||||
return PineconeDB(
|
||||
@@ -25,13 +26,14 @@ def pinecone_db(mock_pinecone_client):
|
||||
hybrid_search=False,
|
||||
metric="cosine",
|
||||
batch_size=100,
|
||||
extra_params=None
|
||||
extra_params=None,
|
||||
)
|
||||
|
||||
|
||||
def test_create_col_existing_index(mock_pinecone_client):
|
||||
# Set up the mock before creating the PineconeDB object
|
||||
mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]
|
||||
|
||||
|
||||
pinecone_db = PineconeDB(
|
||||
collection_name="test_index",
|
||||
embedding_model_dims=128,
|
||||
@@ -43,21 +45,23 @@ def test_create_col_existing_index(mock_pinecone_client):
|
||||
hybrid_search=False,
|
||||
metric="cosine",
|
||||
batch_size=100,
|
||||
extra_params=None
|
||||
extra_params=None,
|
||||
)
|
||||
|
||||
|
||||
# Reset the mock to verify it wasn't called during the test
|
||||
mock_pinecone_client.create_index.reset_mock()
|
||||
|
||||
|
||||
pinecone_db.create_col(128, "cosine")
|
||||
|
||||
|
||||
mock_pinecone_client.create_index.assert_not_called()
|
||||
|
||||
|
||||
def test_create_col_new_index(pinecone_db, mock_pinecone_client):
|
||||
mock_pinecone_client.list_indexes.return_value.names.return_value = []
|
||||
pinecone_db.create_col(128, "cosine")
|
||||
mock_pinecone_client.create_index.assert_called()
|
||||
|
||||
|
||||
def test_insert_vectors(pinecone_db):
|
||||
vectors = [[0.1] * 128, [0.2] * 128]
|
||||
payloads = [{"name": "vector1"}, {"name": "vector2"}]
|
||||
@@ -65,56 +69,61 @@ def test_insert_vectors(pinecone_db):
|
||||
pinecone_db.insert(vectors, payloads, ids)
|
||||
pinecone_db.index.upsert.assert_called()
|
||||
|
||||
|
||||
def test_search_vectors(pinecone_db):
|
||||
pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
|
||||
results = pinecone_db.search("test query",[0.1] * 128, limit=1)
|
||||
results = pinecone_db.search("test query", [0.1] * 128, limit=1)
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "id1"
|
||||
assert results[0].score == 0.9
|
||||
|
||||
|
||||
def test_update_vector(pinecone_db):
|
||||
pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})
|
||||
pinecone_db.index.upsert.assert_called()
|
||||
|
||||
|
||||
def test_get_vector_found(pinecone_db):
|
||||
# Looking at the _parse_output method, it expects a Vector object
|
||||
# or a list of dictionaries, not a dictionary with an 'id' field
|
||||
|
||||
|
||||
# Create a mock Vector object
|
||||
from pinecone.data.dataclasses.vector import Vector
|
||||
mock_vector = Vector(
|
||||
id="id1",
|
||||
values=[0.1] * 128,
|
||||
metadata={"name": "vector1"}
|
||||
)
|
||||
|
||||
|
||||
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
|
||||
|
||||
# Mock the fetch method to return the mock response object
|
||||
mock_response = MagicMock()
|
||||
mock_response.vectors = {"id1": mock_vector}
|
||||
pinecone_db.index.fetch.return_value = mock_response
|
||||
|
||||
|
||||
result = pinecone_db.get("id1")
|
||||
assert result is not None
|
||||
assert result.id == "id1"
|
||||
assert result.payload == {"name": "vector1"}
|
||||
|
||||
|
||||
def test_delete_vector(pinecone_db):
|
||||
pinecone_db.delete("id1")
|
||||
pinecone_db.index.delete.assert_called_with(ids=["id1"])
|
||||
|
||||
|
||||
def test_get_vector_not_found(pinecone_db):
|
||||
pinecone_db.index.fetch.return_value.vectors = {}
|
||||
result = pinecone_db.get("id1")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_cols(pinecone_db):
|
||||
pinecone_db.list_cols()
|
||||
pinecone_db.client.list_indexes.assert_called()
|
||||
|
||||
|
||||
def test_delete_col(pinecone_db):
|
||||
pinecone_db.delete_col()
|
||||
pinecone_db.client.delete_index.assert_called_with("test_index")
|
||||
|
||||
|
||||
def test_col_info(pinecone_db):
|
||||
pinecone_db.col_info()
|
||||
pinecone_db.client.describe_index.assert_called_with("test_index")
|
||||
|
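The get tests above pin down a small contract: index.fetch returns a response whose .vectors dict maps ids to Vector objects, an empty dict means a miss, and a hit surfaces as an object exposing id and payload. A minimal sketch consistent with those assertions — OutputData and the exact parsing are assumptions about mem0's internals, not shown in this diff:

def get(self, vector_id):
    response = self.index.fetch(ids=[vector_id])
    vector = response.vectors.get(vector_id)
    if vector is None:
        return None  # matches test_get_vector_not_found
    # metadata becomes the payload, matching test_get_vector_found
    return OutputData(id=vector.id, score=None, payload=vector.metadata)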
@@ -37,7 +37,7 @@ def supabase_instance(mock_vecs_client, mock_collection):
        index_method=IndexMethod.HNSW,
        index_measure=IndexMeasure.COSINE,
    )

    # Manually set the collection attribute since we're mocking the initialization
    instance.collection = mock_collection
    return instance
@@ -46,14 +46,8 @@ def supabase_instance(mock_vecs_client, mock_collection):
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
    supabase_instance.create_col(1536)

    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
        name="test_collection",
        dimension=1536
    )
    mock_collection.create_index.assert_called_with(
        method="hnsw",
        measure="cosine_distance"
    )
    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536)
    mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")


def test_insert_vectors(supabase_instance, mock_collection):
@@ -63,18 +57,12 @@ def test_insert_vectors(supabase_instance, mock_collection):

    supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

    expected_records = [
        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
    ]
    expected_records = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
    mock_collection.upsert.assert_called_once_with(expected_records)


def test_search_vectors(supabase_instance, mock_collection):
    mock_results = [
        ("id1", 0.9, {"name": "vector1"}),
        ("id2", 0.8, {"name": "vector2"})
    ]
    mock_results = [("id1", 0.9, {"name": "vector1"}), ("id2", 0.8, {"name": "vector2"})]
    mock_collection.query.return_value = mock_results

    vectors = [[0.1, 0.2, 0.3]]
@@ -82,11 +70,7 @@ def test_search_vectors(supabase_instance, mock_collection):
    results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)

    mock_collection.query.assert_called_once_with(
        data=vectors,
        limit=2,
        filters={"category": {"$eq": "test"}},
        include_metadata=True,
        include_value=True
        data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, include_value=True
    )

    assert len(results) == 2
@@ -129,11 +113,8 @@ def test_get_vector(supabase_instance, mock_collection):

def test_list_vectors(supabase_instance, mock_collection):
    mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
    mock_fetch_results = [
        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
    ]

    mock_fetch_results = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]

    mock_collection.query.return_value = mock_query_results
    mock_collection.fetch.return_value = mock_fetch_results

@@ -153,10 +134,7 @@ def test_col_info(supabase_instance, mock_collection):
        "name": "test_collection",
        "count": 100,
        "dimension": 1536,
        "index": {
            "method": "hnsw",
            "metric": "cosine_distance"
        }
        "index": {"method": "hnsw", "metric": "cosine_distance"},
    }


@@ -168,10 +146,7 @@ def test_preprocess_filters(supabase_instance):
    # Test multiple filters
    multi_filter = {"category": "test", "type": "document"}
    assert supabase_instance._preprocess_filters(multi_filter) == {
        "$and": [
            {"category": {"$eq": "test"}},
            {"type": {"$eq": "document"}}
        ]
        "$and": [{"category": {"$eq": "test"}}, {"type": {"$eq": "document"}}]
    }

    # Test None filters
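The assertions above fully determine _preprocess_filters for the tested cases: a single key becomes a bare $eq clause, and multiple keys are joined under $and. A sketch consistent with those expected values — the actual mem0 implementation may differ in detail:

def _preprocess_filters(filters):
    if filters is None:
        return None
    clauses = [{key: {"$eq": value}} for key, value in filters.items()]
    # One filter stays a bare $eq clause; several are AND-ed together.
    return clauses[0] if len(clauses) == 1 else {"$and": clauses}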
@@ -29,9 +29,7 @@ def upstash_instance(mock_index):

@pytest.fixture
def upstash_instance_with_embeddings(mock_index):
    return UpstashVector(
        client=mock_index.return_value, collection_name="ns", enable_embeddings=True
    )
    return UpstashVector(client=mock_index.return_value, collection_name="ns", enable_embeddings=True)


def test_insert_vectors(upstash_instance, mock_index):
@@ -52,12 +50,8 @@ def test_insert_vectors(upstash_instance, mock_index):

def test_search_vectors(upstash_instance, mock_index):
    mock_result = [
        QueryResult(
            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
        ),
        QueryResult(
            id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None
        ),
        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None),
        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None),
    ]

    upstash_instance.client.query_many.return_value = [mock_result]
@@ -93,9 +87,7 @@ def test_delete_vector(upstash_instance):

    upstash_instance.delete(vector_id=vector_id)

    upstash_instance.client.delete.assert_called_once_with(
        ids=[vector_id], namespace="ns"
    )
    upstash_instance.client.delete.assert_called_once_with(ids=[vector_id], namespace="ns")


def test_update_vector(upstash_instance):
@@ -115,18 +107,12 @@ def test_update_vector(upstash_instance):


def test_get_vector(upstash_instance):
    mock_result = [
        QueryResult(
            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
        )
    ]
    mock_result = [QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)]
    upstash_instance.client.fetch.return_value = mock_result

    result = upstash_instance.get(vector_id="id1")

    upstash_instance.client.fetch.assert_called_once_with(
        ids=["id1"], namespace="ns", include_metadata=True
    )
    upstash_instance.client.fetch.assert_called_once_with(ids=["id1"], namespace="ns", include_metadata=True)

    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
@@ -134,15 +120,9 @@ def test_get_vector(upstash_instance):

def test_list_vectors(upstash_instance):
    mock_result = [
        QueryResult(
            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
        ),
        QueryResult(
            id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None
        ),
        QueryResult(
            id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None
        ),
        QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None),
        QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None),
        QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None),
    ]
    handler = MagicMock()

@@ -204,12 +184,8 @@ def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_i

def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
    mock_result = [
        QueryResult(
            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"
        ),
        QueryResult(
            id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"
        ),
        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"),
        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"),
    ]

    upstash_instance_with_embeddings.client.query.return_value = mock_result
@@ -260,9 +236,7 @@ def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embed
        ValueError,
        match="When embeddings are enabled, all payloads must contain a 'data' field",
    ):
        upstash_instance_with_embeddings.insert(
            vectors=vectors, payloads=payloads, ids=ids
        )
        upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)


def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
@@ -316,18 +290,12 @@ def test_get_vector_not_found(upstash_instance):

    result = upstash_instance.get(vector_id="nonexistent")

    upstash_instance.client.fetch.assert_called_once_with(
        ids=["nonexistent"], namespace="ns", include_metadata=True
    )
    upstash_instance.client.fetch.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True)
    assert result is None


def test_search_vectors_empty_filters(upstash_instance):
    mock_result = [
        QueryResult(
            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
        )
    ]
    mock_result = [QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)]
    upstash_instance.client.query_many.return_value = [mock_result]

    vectors = [[0.1, 0.2, 0.3]]
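The ValueError match string above implies an upfront payload check when embeddings are enabled, since Upstash derives the embedding from each record's 'data' text. A plausible shape of the guard these tests exercise — assumed for illustration, not copied from mem0:

def _validate_payloads(payloads, enable_embeddings):
    # Every payload must carry the text Upstash will embed.
    if enable_embeddings and any("data" not in (p or {}) for p in payloads):
        raise ValueError("When embeddings are enabled, all payloads must contain a 'data' field")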
@@ -14,47 +14,50 @@ from mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine

@pytest.fixture
def mock_vertex_ai():
    with patch('google.cloud.aiplatform.MatchingEngineIndex') as mock_index, \
         patch('google.cloud.aiplatform.MatchingEngineIndexEndpoint') as mock_endpoint, \
         patch('google.cloud.aiplatform.init') as mock_init:
    with (
        patch("google.cloud.aiplatform.MatchingEngineIndex") as mock_index,
        patch("google.cloud.aiplatform.MatchingEngineIndexEndpoint") as mock_endpoint,
        patch("google.cloud.aiplatform.init") as mock_init,
    ):
        mock_index_instance = Mock()
        mock_endpoint_instance = Mock()
        yield {
            'index': mock_index_instance,
            'endpoint': mock_endpoint_instance,
            'init': mock_init,
            'index_class': mock_index,
            'endpoint_class': mock_endpoint
            "index": mock_index_instance,
            "endpoint": mock_endpoint_instance,
            "init": mock_init,
            "index_class": mock_index,
            "endpoint_class": mock_endpoint,
        }


@pytest.fixture
def config():
    return GoogleMatchingEngineConfig(
        project_id='test-project',
        project_number='123456789',
        region='us-central1',
        endpoint_id='test-endpoint',
        index_id='test-index',
        deployment_index_id='test-deployment',
        collection_name='test-collection',
        vector_search_api_endpoint='test.vertexai.goog'
        project_id="test-project",
        project_number="123456789",
        region="us-central1",
        endpoint_id="test-endpoint",
        index_id="test-index",
        deployment_index_id="test-deployment",
        collection_name="test-collection",
        vector_search_api_endpoint="test.vertexai.goog",
    )


@pytest.fixture
def vector_store(config, mock_vertex_ai):
    mock_vertex_ai['index_class'].return_value = mock_vertex_ai['index']
    mock_vertex_ai['endpoint_class'].return_value = mock_vertex_ai['endpoint']
    mock_vertex_ai["index_class"].return_value = mock_vertex_ai["index"]
    mock_vertex_ai["endpoint_class"].return_value = mock_vertex_ai["endpoint"]
    return GoogleMatchingEngine(**config.model_dump())


def test_initialization(vector_store, mock_vertex_ai, config):
    """Test proper initialization of GoogleMatchingEngine"""
    mock_vertex_ai['init'].assert_called_once_with(
        project=config.project_id,
        location=config.region
    )
    mock_vertex_ai["init"].assert_called_once_with(project=config.project_id, location=config.region)

    expected_index_path = f"projects/{config.project_number}/locations/{config.region}/indexes/{config.index_id}"
    mock_vertex_ai['index_class'].assert_called_once_with(index_name=expected_index_path)
    mock_vertex_ai["index_class"].assert_called_once_with(index_name=expected_index_path)


def test_insert_vectors(vector_store, mock_vertex_ai):
    """Test inserting vectors with payloads"""
@@ -64,13 +67,14 @@ def test_insert_vectors(vector_store, mock_vertex_ai):

    vector_store.insert(vectors=vectors, payloads=payloads, ids=ids)

    mock_vertex_ai['index'].upsert_datapoints.assert_called_once()
    call_args = mock_vertex_ai['index'].upsert_datapoints.call_args[1]
    assert len(call_args['datapoints']) == 1
    datapoint_str = str(call_args['datapoints'][0])
    mock_vertex_ai["index"].upsert_datapoints.assert_called_once()
    call_args = mock_vertex_ai["index"].upsert_datapoints.call_args[1]
    assert len(call_args["datapoints"]) == 1
    datapoint_str = str(call_args["datapoints"][0])
    assert "test-id" in datapoint_str
    assert "0.1" in datapoint_str and "0.2" in datapoint_str and "0.3" in datapoint_str


def test_search_vectors(vector_store, mock_vertex_ai):
    """Test searching vectors with filters"""
    vectors = [[0.1, 0.2, 0.3]]
@@ -85,7 +89,7 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    mock_restrict.allow_list = ["test_user"]
    mock_restrict.name = "user_id"
    mock_restrict.allow_tokens = ["test_user"]

    mock_datapoint.restricts = [mock_restrict]

    mock_neighbor = Mock()
@@ -94,16 +98,16 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    mock_neighbor.datapoint = mock_datapoint
    mock_neighbor.restricts = [mock_restrict]

    mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
    mock_vertex_ai["endpoint"].find_neighbors.return_value = [[mock_neighbor]]

    results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)

    mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
    mock_vertex_ai["endpoint"].find_neighbors.assert_called_once_with(
        deployed_index_id=vector_store.deployment_index_id,
        queries=[vectors],
        num_neighbors=1,
        filter=[Namespace("user_id", ["test_user"], [])],
        return_full_datapoint=True
        return_full_datapoint=True,
    )

    assert len(results) == 1
@@ -111,29 +115,27 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    assert results[0].score == 0.1
    assert results[0].payload == {"user_id": "test_user"}


def test_delete(vector_store, mock_vertex_ai):
    """Test deleting vectors"""
    vector_id = "test-id"

    remove_mock = Mock()

    with patch.object(GoogleMatchingEngine, 'delete', wraps=vector_store.delete) as delete_spy:
        with patch.object(vector_store.index, 'remove_datapoints', remove_mock):
    with patch.object(GoogleMatchingEngine, "delete", wraps=vector_store.delete) as delete_spy:
        with patch.object(vector_store.index, "remove_datapoints", remove_mock):
            vector_store.delete(ids=[vector_id])

    delete_spy.assert_called_once_with(ids=[vector_id])
    remove_mock.assert_called_once_with(datapoint_ids=[vector_id])


def test_error_handling(vector_store, mock_vertex_ai):
    """Test error handling during operations"""
    mock_vertex_ai['index'].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
    mock_vertex_ai["index"].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")

    with pytest.raises(Exception) as exc_info:
        vector_store.insert(
            vectors=[[0.1, 0.2, 0.3]],
            payloads=[{"name": "test"}],
            ids=["test-id"]
        )
        vector_store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"name": "test"}], ids=["test-id"])

    assert isinstance(exc_info.value, exceptions.InvalidArgument)
    assert "Invalid request" in str(exc_info.value)
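The find_neighbors assertion above shows how a flat filters dict is expected to surface as Matching Engine restricts: each key/value pair becomes a Namespace with the value as its only allow token. A hedged sketch of that conversion — the helper name is illustrative, not from this commit:

from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import Namespace

def filters_to_namespaces(filters):
    # {"user_id": "test_user"} -> [Namespace("user_id", ["test_user"], [])]
    return [Namespace(key, [value], []) for key, value in (filters or {}).items()]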
@@ -76,15 +76,15 @@
# self.client_mock.batch = MagicMock()
# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()

# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
#     "results": [{"id": "id1"}, {"id": "id2"}]
# }

# vectors = [[0.1] * 1536, [0.2] * 1536]
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]

# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)

# def test_get(self):
@@ -108,7 +108,7 @@
# result = self.weaviate_db.get(vector_id=valid_uuid)

# assert result.id == valid_uuid

# expected_payload = mock_response.properties.copy()
# expected_payload["id"] = valid_uuid

@@ -131,10 +131,10 @@
# "metadata": {"distance": 0.2}
# }
# ]

# mock_response = MagicMock()
# mock_response.objects = []

# for obj in mock_objects:
# mock_obj = MagicMock()
# mock_obj.uuid = obj["uuid"]
@@ -142,16 +142,16 @@
# mock_obj.metadata = MagicMock()
# mock_obj.metadata.distance = obj["metadata"]["distance"]
# mock_response.objects.append(mock_obj)

# mock_hybrid = MagicMock()
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
# mock_hybrid.return_value = mock_response

# vectors = [[0.1] * 1536]
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)

# mock_hybrid.assert_called_once()

# self.assertEqual(len(results), 1)
# self.assertEqual(results[0].id, "id1")
# self.assertEqual(results[0].score, 0.8)
@@ -163,28 +163,28 @@

# def test_list(self):
#     mock_objects = []

#     mock_obj1 = MagicMock()
#     mock_obj1.uuid = "id1"
#     mock_obj1.properties = {"key1": "value1"}
#     mock_objects.append(mock_obj1)

#     mock_obj2 = MagicMock()
#     mock_obj2.uuid = "id2"
#     mock_obj2.properties = {"key2": "value2"}
#     mock_objects.append(mock_obj2)

#     mock_response = MagicMock()
#     mock_response.objects = mock_objects

#     mock_fetch = MagicMock()
#     self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
#     mock_fetch.return_value = mock_response

#     results = self.weaviate_db.list(limit=10)

#     mock_fetch.assert_called_once()

#     # Verify results
#     self.assertEqual(len(results), 1)
#     self.assertEqual(len(results[0]), 2)