Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -5,13 +5,15 @@ def test_get_update_memory_messages():
retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}]
response_content = ["new fact"]
custom_update_memory_prompt = "custom prompt determining memory update"
## When custom update memory prompt is provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt)
result = prompts.get_update_memory_messages(
retrieved_old_memory_dict, response_content, custom_update_memory_prompt
)
assert result.startswith(custom_update_memory_prompt)
## When custom update memory prompt is not provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)

View File

@@ -10,9 +10,7 @@ from mem0.embeddings.lmstudio import LMStudioEmbedding
def mock_lm_studio_client():
with patch("mem0.embeddings.lmstudio.OpenAI") as mock_openai:
mock_client = Mock()
mock_client.embeddings.create.return_value = Mock(
data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])]
)
mock_client.embeddings.create.return_value = Mock(data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])])
mock_openai.return_value = mock_client
yield mock_client

View File

@@ -23,7 +23,9 @@ def test_embed_default_model(mock_openai_client):
result = embedder.embed("Hello world")
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello world"], model="text-embedding-3-small", dimensions=1536
)
assert result == [0.1, 0.2, 0.3]
@@ -37,7 +39,7 @@ def test_embed_custom_model(mock_openai_client):
result = embedder.embed("Test embedding")
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024
input=["Test embedding"], model="text-embedding-2-medium", dimensions=1024
)
assert result == [0.4, 0.5, 0.6]
@@ -51,7 +53,9 @@ def test_embed_removes_newlines(mock_openai_client):
result = embedder.embed("Hello\nworld")
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello world"], model="text-embedding-3-small", dimensions=1536
)
assert result == [0.7, 0.8, 0.9]
@@ -65,7 +69,7 @@ def test_embed_without_api_key_env_var(mock_openai_client):
result = embedder.embed("Testing API key")
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536
input=["Testing API key"], model="text-embedding-3-small", dimensions=1536
)
assert result == [1.0, 1.1, 1.2]
@@ -81,6 +85,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch):
result = embedder.embed("Environment key test")
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536
input=["Environment key test"], model="text-embedding-3-small", dimensions=1536
)
assert result == [1.3, 1.4, 1.5]

View File

@@ -24,11 +24,20 @@ def mock_config():
with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config:
mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json"
yield mock_config
@pytest.fixture
def mock_embedding_types():
return ["SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY", "QUESTION_ANSWERING", "FACT_VERIFICATION", "CODE_RETRIEVAL_QUERY"]
return [
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"RETRIEVAL_DOCUMENT",
"RETRIEVAL_QUERY",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
"CODE_RETRIEVAL_QUERY",
]
@pytest.fixture
@@ -79,30 +88,31 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con
assert result == [0.4, 0.5, 0.6]
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_memory_action(mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input):
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_memory_action(
mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input
):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256
for embedding_type in mock_embedding_types:
mock_config.return_value.memory_add_embedding_type = embedding_type
mock_config.return_value.memory_update_embedding_type = embedding_type
mock_config.return_value.memory_search_embedding_type = embedding_type
config = mock_config()
embedder = VertexAIEmbedding(config)
mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004")
for memory_action in ["add", "update", "search"]:
embedder.embed("Hello world", memory_action=memory_action)
mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type)
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with(
texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256
)
@patch("mem0.embeddings.vertexai.os")
def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config):
@@ -137,15 +147,15 @@ def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_envi
result = embedder.embed("Large embedding test")
assert result == [0.1] * 1024
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_invalid_memory_action(mock_text_embedding_model, mock_config):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256
config = mock_config()
embedder = VertexAIEmbedding(config)
with pytest.raises(ValueError):
embedder.embed("Hello world", memory_action="invalid_action")
embedder.embed("Hello world", memory_action="invalid_action")

View File

@@ -127,4 +127,4 @@ def test_generate_with_http_proxies(default_headers):
api_version=None,
default_headers=default_headers,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")

View File

@@ -31,12 +31,12 @@ def test_deepseek_llm_base_url():
# case3: with config.deepseek_base_url
config_base_url = "https://api.config.com/v1/"
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url,
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url
@@ -99,16 +99,16 @@ def test_generate_response_with_tools(mock_deepseek_client):
response = llm.generate_response(messages, tools=tools)
mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto"
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

View File

@@ -10,6 +10,7 @@ try:
from langchain.chat_models.base import BaseChatModel
except ImportError:
from unittest.mock import MagicMock
BaseChatModel = MagicMock
@@ -24,16 +25,11 @@ def mock_langchain_model():
def test_langchain_initialization(mock_langchain_model):
"""Test that LangchainLLM initializes correctly with a valid model."""
# Create a config with the model instance directly
config = BaseLlmConfig(
model=mock_langchain_model,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
# Initialize the LangchainLLM
llm = LangchainLLM(config)
# Verify the model was correctly assigned
assert llm.langchain_model == mock_langchain_model
@@ -41,35 +37,30 @@ def test_langchain_initialization(mock_langchain_model):
def test_generate_response(mock_langchain_model):
"""Test that generate_response correctly processes messages and returns a response."""
# Create a config with the model instance
config = BaseLlmConfig(
model=mock_langchain_model,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
# Initialize the LangchainLLM
llm = LangchainLLM(config)
# Create test messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{"role": "user", "content": "Tell me a joke."}
{"role": "user", "content": "Tell me a joke."},
]
# Get response
response = llm.generate_response(messages)
# Verify the correct message format was passed to the model
expected_langchain_messages = [
("system", "You are a helpful assistant."),
("human", "Hello, how are you?"),
("ai", "I'm doing well! How can I help you?"),
("human", "Tell me a joke.")
("human", "Tell me a joke."),
]
mock_langchain_model.invoke.assert_called_once()
# Extract the first argument of the first call
actual_messages = mock_langchain_model.invoke.call_args[0][0]
@@ -79,25 +70,15 @@ def test_generate_response(mock_langchain_model):
def test_invalid_model():
"""Test that LangchainLLM raises an error with an invalid model."""
config = BaseLlmConfig(
model="not-a-valid-model-instance",
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, api_key="test-api-key")
with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"):
LangchainLLM(config)
def test_missing_model():
"""Test that LangchainLLM raises an error when model is None."""
config = BaseLlmConfig(
model=None,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model=None, temperature=0.7, max_tokens=100, api_key="test-api-key")
with pytest.raises(ValueError, match="`model` parameter is required"):
LangchainLLM(config)

View File

@@ -11,9 +11,7 @@ def mock_lm_studio_client():
with patch("mem0.llms.lmstudio.OpenAI") as mock_openai: # Corrected path
mock_client = Mock()
mock_client.chat.completions.create.return_value = Mock(
choices=[
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
choices=[Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
)
mock_openai.return_value = mock_client
yield mock_client

View File

@@ -10,18 +10,19 @@ def _setup_mocks(mocker):
"""Helper to setup common mocks for both sync and async fixtures"""
mock_embedder = mocker.MagicMock()
mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3]
mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder)
mocker.patch("mem0.utils.factory.EmbedderFactory.create", mock_embedder)
mock_vector_store = mocker.MagicMock()
mock_vector_store.return_value.search.return_value = []
mocker.patch('mem0.utils.factory.VectorStoreFactory.create',
side_effect=[mock_vector_store.return_value, mocker.MagicMock()])
mocker.patch(
"mem0.utils.factory.VectorStoreFactory.create", side_effect=[mock_vector_store.return_value, mocker.MagicMock()]
)
mock_llm = mocker.MagicMock()
mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm)
mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock())
mocker.patch("mem0.utils.factory.LlmFactory.create", mock_llm)
mocker.patch("mem0.memory.storage.SQLiteManager", mocker.MagicMock())
return mock_llm, mock_vector_store
@@ -30,29 +31,26 @@ class TestAddToVectorStoreErrors:
def mock_memory(self, mocker):
"""Fixture that returns a Memory instance with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker)
memory = Memory()
memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1"
return memory
def test_empty_llm_response_fact_extraction(self, mock_memory, caplog):
"""Test empty response from LLM during fact extraction"""
# Setup
mock_memory.llm.generate_response.return_value = ""
# Execute
with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
# Verify
assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed
@@ -62,20 +60,14 @@ class TestAddToVectorStoreErrors:
"""Test empty response from LLM during memory actions"""
# Setup
# First call returns valid JSON, second call returns empty string
mock_memory.llm.generate_response.side_effect = [
'{"facts": ["test fact"]}',
""
]
mock_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]
# Execute
with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
# Verify
assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed
@@ -88,48 +80,39 @@ class TestAsyncAddToVectorStoreErrors:
def mock_async_memory(self, mocker):
"""Fixture for AsyncMemory with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker)
memory = AsyncMemory()
memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1"
return memory
@pytest.mark.asyncio
async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store"""
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.return_value = ""
with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
assert result == []
assert "Error in new_retrieved_facts" in caplog.text
@pytest.mark.asyncio
async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store"""
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
mock_async_memory.llm.generate_response.side_effect = [
'{"facts": ["test fact"]}',
""
]
mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]
with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
assert result == []
assert "Invalid JSON response" in caplog.text

View File

@@ -17,11 +17,13 @@ def mock_openai():
@pytest.fixture
def memory_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
"mem0.utils.factory.VectorStoreFactory"
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
"mem0.memory.telemetry.capture_event"
), patch("mem0.memory.graph_memory.MemoryGraph"):
with (
patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
patch("mem0.utils.factory.LlmFactory") as mock_llm,
patch("mem0.memory.telemetry.capture_event"),
patch("mem0.memory.graph_memory.MemoryGraph"),
):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
@@ -30,13 +32,16 @@ def memory_instance():
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@pytest.fixture
def memory_custom_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
"mem0.utils.factory.VectorStoreFactory"
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
"mem0.memory.telemetry.capture_event"
), patch("mem0.memory.graph_memory.MemoryGraph"):
with (
patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
patch("mem0.utils.factory.LlmFactory") as mock_llm,
patch("mem0.memory.telemetry.capture_event"),
patch("mem0.memory.graph_memory.MemoryGraph"),
):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
@@ -44,7 +49,7 @@ def memory_custom_instance():
config = MemoryConfig(
version="v1.1",
custom_fact_extraction_prompt="custom prompt extracting memory",
custom_update_memory_prompt="custom prompt determining memory update"
custom_update_memory_prompt="custom prompt determining memory update",
)
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@@ -194,7 +199,6 @@ def test_delete_all(memory_instance, version, enable_graph):
assert result["message"] == "Memories deleted successfully!"
@pytest.mark.parametrize(
"version, enable_graph, expected_result",
[
@@ -242,20 +246,22 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
else:
memory_instance.graph.get_all.assert_not_called()
def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
memory_custom_instance.llm.generate_response = Mock()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages:
with patch(
"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"
) as mock_get_update_memory_messages:
memory_custom_instance.add(messages=messages, user_id="test_user")
## custom prompt
##
mock_parse_messages.assert_called_once_with(messages)
memory_custom_instance.llm.generate_response.assert_any_call(
messages=[
{"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},
@@ -263,12 +269,14 @@ def test_custom_prompts(memory_custom_instance):
],
response_format={"type": "json_object"},
)
## custom update memory prompt
##
mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt)
mock_get_update_memory_messages.assert_called_once_with(
[], [], memory_custom_instance.config.custom_update_memory_prompt
)
memory_custom_instance.llm.generate_response.assert_any_call(
messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
response_format={"type": "json_object"},
)
)

View File

@@ -97,4 +97,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm
call_args = mock_litellm.completion.call_args[1]
assert call_args["messages"][0]["role"] == "system"
assert call_args["messages"][0]["content"] == "You are a helpful assistant."
assert call_args["messages"][0]["content"] == "You are a helpful assistant."

View File

@@ -13,13 +13,15 @@ from mem0.vector_stores.azure_ai_search import AzureAISearch
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture
def mock_clients():
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential:
with (
patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient,
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient,
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential,
):
# Create mocked instances for search and index clients.
mock_search_client = MockSearchClient.return_value
mock_index_client = MockIndexClient.return_value
# Mock the client._client._config.user_agent_policy.add_user_agent
mock_search_client._client = MagicMock()
mock_search_client._client._config.user_agent_policy.add_user_agent = Mock()
@@ -62,7 +64,7 @@ def azure_ai_search_instance(mock_clients):
api_key="test-api-key",
embedding_model_dims=3,
compression_type="binary", # testing binary quantization option
use_float16=True
use_float16=True,
)
# Return instance and clients for verification.
return instance, mock_search_client, mock_index_client
@@ -70,21 +72,18 @@ def azure_ai_search_instance(mock_clients):
# --- Tests for AzureAISearchConfig ---
def test_config_validation_valid():
"""Test valid configurations are accepted."""
# Test minimal configuration
config = AzureAISearchConfig(
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768
)
config = AzureAISearchConfig(service_name="test-service", api_key="test-api-key", embedding_model_dims=768)
assert config.collection_name == "mem0" # Default value
assert config.service_name == "test-service"
assert config.api_key == "test-api-key"
assert config.embedding_model_dims == 768
assert config.compression_type is None
assert config.use_float16 is False
# Test with all optional parameters
config = AzureAISearchConfig(
collection_name="custom-index",
@@ -92,7 +91,7 @@ def test_config_validation_valid():
api_key="test-api-key",
embedding_model_dims=1536,
compression_type="scalar",
use_float16=True
use_float16=True,
)
assert config.collection_name == "custom-index"
assert config.compression_type == "scalar"
@@ -106,7 +105,7 @@ def test_config_validation_invalid_compression_type():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="invalid-type" # Not a valid option
compression_type="invalid-type", # Not a valid option
)
assert "Invalid compression_type" in str(exc_info.value)
@@ -118,7 +117,7 @@ def test_config_validation_deprecated_use_compression():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
use_compression=True # Deprecated parameter
use_compression=True, # Deprecated parameter
)
# Fix: Use a partial string match instead of exact match
assert "use_compression" in str(exc_info.value)
@@ -132,7 +131,7 @@ def test_config_validation_extra_fields():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
unknown_parameter="value" # Extra field
unknown_parameter="value", # Extra field
)
assert "Extra fields not allowed" in str(exc_info.value)
assert "unknown_parameter" in str(exc_info.value)
@@ -140,30 +139,28 @@ def test_config_validation_extra_fields():
# --- Tests for AzureAISearch initialization ---
def test_initialization(mock_clients):
"""Test AzureAISearch initialization with different parameters."""
mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients
# Test with minimal parameters
instance = AzureAISearch(
service_name="test-service",
collection_name="test-index",
api_key="test-api-key",
embedding_model_dims=768
service_name="test-service", collection_name="test-index", api_key="test-api-key", embedding_model_dims=768
)
# Verify initialization parameters
assert instance.index_name == "test-index"
assert instance.collection_name == "test-index"
assert instance.embedding_model_dims == 768
assert instance.compression_type == "none" # Default when None is passed
assert instance.use_float16 is False
# Verify client creation
mock_azure_key_credential.assert_called_with("test-api-key")
assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0]
assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0]
# Verify index creation was called
mock_index_client.create_or_update_index.assert_called_once()
@@ -171,75 +168,75 @@ def test_initialization(mock_clients):
def test_initialization_with_compression_types(mock_clients):
"""Test initialization with different compression types."""
mock_search_client, mock_index_client, _ = mock_clients
# Test with scalar compression
instance = AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="scalar"
compression_type="scalar",
)
assert instance.compression_type == "scalar"
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify scalar compression was configured
assert hasattr(index.vector_search, 'compressions')
assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Test with binary compression
instance = AzureAISearch(
service_name="test-service",
collection_name="binary-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="binary"
compression_type="binary",
)
assert instance.compression_type == "binary"
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify binary compression was configured
assert hasattr(index.vector_search, 'compressions')
assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0
assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Test with no compression
instance = AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type=None
compression_type=None,
)
assert instance.compression_type == "none"
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify no compression was configured
assert hasattr(index.vector_search, 'compressions')
assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) == 0
def test_initialization_with_float_precision(mock_clients):
"""Test initialization with different float precision settings."""
mock_search_client, mock_index_client, _ = mock_clients
# Test with half precision (float16)
instance = AzureAISearch(
service_name="test-service",
collection_name="float16-index",
api_key="test-api-key",
embedding_model_dims=768,
use_float16=True
use_float16=True,
)
assert instance.use_float16 is True
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
@@ -247,17 +244,17 @@ def test_initialization_with_float_precision(mock_clients):
vector_field = next((f for f in index.fields if f.name == "vector"), None)
assert vector_field is not None
assert "Edm.Half" in vector_field.type
# Test with full precision (float32)
instance = AzureAISearch(
service_name="test-service",
collection_name="float32-index",
api_key="test-api-key",
embedding_model_dims=768,
use_float16=False
use_float16=False,
)
assert instance.use_float16 is False
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
@@ -269,21 +266,22 @@ def test_initialization_with_float_precision(mock_clients):
# --- Tests for create_col method ---
def test_create_col(azure_ai_search_instance):
"""Test the create_col method creates an index with the correct configuration."""
instance, _, mock_index_client = azure_ai_search_instance
# create_col is called during initialization, so we check the call that was already made
mock_index_client.create_or_update_index.assert_called_once()
# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]
# Check basic properties
assert index.name == "test-index"
assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload
# Check that required fields are present
field_names = [f.name for f in index.fields]
assert "id" in field_names
@@ -292,22 +290,22 @@ def test_create_col(azure_ai_search_instance):
assert "user_id" in field_names
assert "run_id" in field_names
assert "agent_id" in field_names
# Check that id is the key field
id_field = next(f for f in index.fields if f.name == "id")
assert id_field.key is True
# Check vector search configuration
assert index.vector_search is not None
assert len(index.vector_search.profiles) == 1
assert index.vector_search.profiles[0].name == "my-vector-config"
assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config"
# Check algorithms
assert len(index.vector_search.algorithms) == 1
assert index.vector_search.algorithms[0].name == "my-algorithms-config"
assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0]))
# With binary compression and float16, we should have compression configuration
assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression"
@@ -317,24 +315,24 @@ def test_create_col(azure_ai_search_instance):
def test_create_col_scalar_compression(mock_clients):
"""Test creating a collection with scalar compression."""
mock_search_client, mock_index_client, _ = mock_clients
AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="scalar"
compression_type="scalar",
)
# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]
# Check compression configuration
assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression"
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Check profile references compression
assert index.vector_search.profiles[0].compression_name == "myCompression"
@@ -342,28 +340,29 @@ def test_create_col_scalar_compression(mock_clients):
def test_create_col_no_compression(mock_clients):
"""Test creating a collection with no compression."""
mock_search_client, mock_index_client, _ = mock_clients
AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type=None
compression_type=None,
)
# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]
# Check compression configuration - should be empty
assert len(index.vector_search.compressions) == 0
# Check profile doesn't reference compression
assert index.vector_search.profiles[0].compression_name is None
# --- Tests for insert method ---
def test_insert_single(azure_ai_search_instance):
"""Test inserting a single vector."""
instance, mock_search_client, _ = azure_ai_search_instance
@@ -372,9 +371,7 @@ def test_insert_single(azure_ai_search_instance):
ids = ["doc1"]
# Fix: Include status_code: 201 in mock response
mock_search_client.upload_documents.return_value = [
{"status": True, "id": "doc1", "status_code": 201}
]
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1", "status_code": 201}]
instance.insert(vectors, payloads, ids)
@@ -396,35 +393,35 @@ def test_insert_single(azure_ai_search_instance):
def test_insert_multiple(azure_ai_search_instance):
"""Test inserting multiple vectors in one call."""
instance, mock_search_client, _ = azure_ai_search_instance
# Create multiple vectors
num_docs = 3
vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)]
vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)]
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
ids = [f"doc{i}" for i in range(num_docs)]
# Configure mock to return success for all documents (fix: add status_code 201)
mock_search_client.upload_documents.return_value = [
{"status": True, "id": id_val, "status_code": 201} for id_val in ids
]
# Insert the documents
instance.insert(vectors, payloads, ids)
# Verify upload_documents was called with correct documents
mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args
documents = args[0]
# Verify all documents were included
assert len(documents) == num_docs
# Check first document
assert documents[0]["id"] == "doc0"
assert documents[0]["vector"] == [0.0, 0.1, 0.2]
assert documents[0]["payload"] == json.dumps(payloads[0])
assert documents[0]["user_id"] == "user0"
# Check last document
assert documents[2]["id"] == "doc2"
assert documents[2]["vector"] == [0.2, 0.3, 0.4]
@@ -437,9 +434,7 @@ def test_insert_with_error(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to return an error for one document
mock_search_client.upload_documents.return_value = [
{"status": False, "id": "doc1", "errorMessage": "Azure error"}
]
mock_search_client.upload_documents.return_value = [{"status": False, "id": "doc1", "errorMessage": "Azure error"}]
vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
@@ -454,7 +449,7 @@ def test_insert_with_error(azure_ai_search_instance):
# Configure mock to return mixed success/failure for multiple documents
mock_search_client.upload_documents.return_value = [
{"status": True, "id": "doc1"}, # This should not cause failure
{"status": False, "id": "doc2", "errorMessage": "Azure error"}
{"status": False, "id": "doc2", "errorMessage": "Azure error"},
]
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -465,8 +460,9 @@ def test_insert_with_error(azure_ai_search_instance):
with pytest.raises(Exception) as exc_info:
instance.insert(vectors, payloads, ids)
assert "Insert failed for document doc2" in str(exc_info.value) or \
"Insert failed for document doc1" in str(exc_info.value)
assert "Insert failed for document doc2" in str(exc_info.value) or "Insert failed for document doc1" in str(
exc_info.value
)
def test_insert_with_missing_payload_fields(azure_ai_search_instance):
@@ -500,23 +496,24 @@ def test_insert_with_missing_payload_fields(azure_ai_search_instance):
def test_insert_with_http_error(azure_ai_search_instance):
"""Test insert when Azure client throws an HTTP error."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to raise an HttpResponseError
mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")
vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
ids = ["doc1"]
# Insert should propagate the HTTP error
with pytest.raises(HttpResponseError) as exc_info:
instance.insert(vectors, payloads, ids)
assert "Azure service error" in str(exc_info.value)
# --- Tests for search method ---
def test_search_basic(azure_ai_search_instance):
"""Test basic vector search without filters."""
instance, mock_search_client, _ = azure_ai_search_instance
@@ -536,9 +533,7 @@ def test_search_basic(azure_ai_search_instance):
# Search with a vector
query_text = "test query" # Add a query string
query_vector = [0.1, 0.2, 0.3]
results = instance.search(
query_text, query_vector, limit=5
) # Pass the query string
results = instance.search(query_text, query_vector, limit=5) # Pass the query string
# Verify search was called correctly
mock_search_client.search.assert_called_once()

View File

@@ -7,9 +7,7 @@ import dotenv
try:
from elasticsearch import Elasticsearch
except ImportError:
raise ImportError(
"Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`"
) from None
raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None
from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData
@@ -19,20 +17,20 @@ class TestElasticsearchDB(unittest.TestCase):
def setUpClass(cls):
# Load environment variables before any test
dotenv.load_dotenv()
# Save original environment variables
cls.original_env = {
'ES_URL': os.getenv('ES_URL', 'http://localhost:9200'),
'ES_USERNAME': os.getenv('ES_USERNAME', 'test_user'),
'ES_PASSWORD': os.getenv('ES_PASSWORD', 'test_password'),
'ES_CLOUD_ID': os.getenv('ES_CLOUD_ID', 'test_cloud_id')
"ES_URL": os.getenv("ES_URL", "http://localhost:9200"),
"ES_USERNAME": os.getenv("ES_USERNAME", "test_user"),
"ES_PASSWORD": os.getenv("ES_PASSWORD", "test_password"),
"ES_CLOUD_ID": os.getenv("ES_CLOUD_ID", "test_cloud_id"),
}
# Set test environment variables
os.environ['ES_URL'] = 'http://localhost'
os.environ['ES_USERNAME'] = 'test_user'
os.environ['ES_PASSWORD'] = 'test_password'
os.environ["ES_URL"] = "http://localhost"
os.environ["ES_USERNAME"] = "test_user"
os.environ["ES_PASSWORD"] = "test_password"
def setUp(self):
# Create a mock Elasticsearch client with proper attributes
self.client_mock = MagicMock(spec=Elasticsearch)
@@ -41,25 +39,25 @@ class TestElasticsearchDB(unittest.TestCase):
self.client_mock.indices.create = MagicMock()
self.client_mock.indices.delete = MagicMock()
self.client_mock.indices.get_alias = MagicMock()
# Start patches BEFORE creating ElasticsearchDB instance
patcher = patch('mem0.vector_stores.elasticsearch.Elasticsearch', return_value=self.client_mock)
patcher = patch("mem0.vector_stores.elasticsearch.Elasticsearch", return_value=self.client_mock)
self.mock_es = patcher.start()
self.addCleanup(patcher.stop)
# Initialize ElasticsearchDB with test config and auto_create_index=False
self.es_db = ElasticsearchDB(
host=os.getenv('ES_URL'),
host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'),
password=os.getenv('ES_PASSWORD'),
user=os.getenv("ES_USERNAME"),
password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
auto_create_index=False # Disable auto creation for tests
auto_create_index=False, # Disable auto creation for tests
)
# Reset mock counts after initialization
self.client_mock.reset_mock()
@@ -80,15 +78,15 @@ class TestElasticsearchDB(unittest.TestCase):
# Test when index doesn't exist
self.client_mock.indices.exists.return_value = False
self.es_db.create_index()
# Verify index creation was called with correct settings
self.client_mock.indices.create.assert_called_once()
create_args = self.client_mock.indices.create.call_args[1]
# Verify basic index settings
self.assertEqual(create_args["index"], "test_collection")
self.assertIn("mappings", create_args["body"])
# Verify field mappings
mappings = create_args["body"]["mappings"]["properties"]
self.assertEqual(mappings["text"]["type"], "text")
@@ -97,53 +95,53 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(mappings["vector"]["index"], True)
self.assertEqual(mappings["vector"]["similarity"], "cosine")
self.assertEqual(mappings["metadata"]["type"], "object")
# Reset mocks for next test
self.client_mock.reset_mock()
# Test when index already exists
self.client_mock.indices.exists.return_value = True
self.es_db.create_index()
# Verify create was not called when index exists
self.client_mock.indices.create.assert_not_called()
def test_auto_create_index(self):
# Reset mock
self.client_mock.reset_mock()
# Test with auto_create_index=True
ElasticsearchDB(
host=os.getenv('ES_URL'),
host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'),
password=os.getenv('ES_PASSWORD'),
user=os.getenv("ES_USERNAME"),
password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
auto_create_index=True
auto_create_index=True,
)
# Verify create_index was called during initialization
self.client_mock.indices.exists.assert_called_once()
# Reset mock
self.client_mock.reset_mock()
# Test with auto_create_index=False
ElasticsearchDB(
host=os.getenv('ES_URL'),
host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'),
password=os.getenv('ES_PASSWORD'),
user=os.getenv("ES_USERNAME"),
password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
auto_create_index=False
auto_create_index=False,
)
# Verify create_index was not called during initialization
self.client_mock.indices.exists.assert_not_called()
@@ -152,17 +150,17 @@ class TestElasticsearchDB(unittest.TestCase):
vectors = [[0.1] * 1536, [0.2] * 1536]
payloads = [{"key1": "value1"}, {"key2": "value2"}]
ids = ["id1", "id2"]
# Mock bulk operation
with patch('mem0.vector_stores.elasticsearch.bulk') as mock_bulk:
with patch("mem0.vector_stores.elasticsearch.bulk") as mock_bulk:
mock_bulk.return_value = (2, []) # Simulate successful bulk insert
# Perform insert
results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify bulk was called
mock_bulk.assert_called_once()
# Verify bulk actions format
actions = mock_bulk.call_args[0][1]
self.assertEqual(len(actions), 2)
@@ -170,7 +168,7 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(actions[0]["_id"], "id1")
self.assertEqual(actions[0]["_source"]["vector"], vectors[0])
self.assertEqual(actions[0]["_source"]["metadata"], payloads[0])
# Verify returned objects
self.assertEqual(len(results), 2)
self.assertIsInstance(results[0], OutputData)
@@ -182,14 +180,7 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = {
"hits": {
"hits": [
{
"_id": "id1",
"_score": 0.8,
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key1": "value1"}
}
}
{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}
]
}
}
@@ -206,7 +197,7 @@ class TestElasticsearchDB(unittest.TestCase):
# Verify search parameters
self.assertEqual(search_args["index"], "test_collection")
body = search_args["body"]
# Verify KNN query structure
self.assertIn("knn", body)
self.assertEqual(body["knn"]["field"], "vector")
@@ -235,29 +226,24 @@ class TestElasticsearchDB(unittest.TestCase):
self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)
# Verify custom search query was used
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
self.client_mock.search.assert_called_once_with(
index=self.es_db.collection_name, body={"custom_key": "custom_value"}
)
def test_get(self):
# Mock get response with correct structure
mock_response = {
"_id": "id1",
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key": "value"},
"text": "sample text"
}
"_source": {"vector": [0.1] * 1536, "metadata": {"key": "value"}, "text": "sample text"},
}
self.client_mock.get.return_value = mock_response
# Perform get
result = self.es_db.get(vector_id="id1")
# Verify get call
self.client_mock.get.assert_called_once_with(
index="test_collection",
id="id1"
)
self.client_mock.get.assert_called_once_with(index="test_collection", id="id1")
# Verify result
self.assertIsNotNone(result)
self.assertEqual(result.id, "id1")
@@ -267,7 +253,7 @@ class TestElasticsearchDB(unittest.TestCase):
def test_get_not_found(self):
# Mock get raising exception
self.client_mock.get.side_effect = Exception("Not found")
# Verify get returns None when document not found
result = self.es_db.get(vector_id="nonexistent")
self.assertIsNone(result)
@@ -277,33 +263,19 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = {
"hits": {
"hits": [
{
"_id": "id1",
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key1": "value1"}
},
"_score": 1.0
},
{
"_id": "id2",
"_source": {
"vector": [0.2] * 1536,
"metadata": {"key2": "value2"}
},
"_score": 0.8
}
{"_id": "id1", "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}, "_score": 1.0},
{"_id": "id2", "_source": {"vector": [0.2] * 1536, "metadata": {"key2": "value2"}}, "_score": 0.8},
]
}
}
self.client_mock.search.return_value = mock_response
# Perform list operation
results = self.es_db.list(limit=10)
# Verify search call
self.client_mock.search.assert_called_once()
# Verify results
self.assertEqual(len(results), 1) # Outer list
self.assertEqual(len(results[0]), 2) # Inner list
@@ -316,30 +288,24 @@ class TestElasticsearchDB(unittest.TestCase):
def test_delete(self):
# Perform delete
self.es_db.delete(vector_id="id1")
# Verify delete call
self.client_mock.delete.assert_called_once_with(
index="test_collection",
id="id1"
)
self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1")
def test_list_cols(self):
# Mock indices response
mock_indices = {"index1": {}, "index2": {}}
self.client_mock.indices.get_alias.return_value = mock_indices
# Get collections
result = self.es_db.list_cols()
# Verify result
self.assertEqual(result, ["index1", "index2"])
def test_delete_col(self):
# Delete collection
self.es_db.delete_col()
# Verify delete call
self.client_mock.indices.delete.assert_called_once_with(
index="test_collection"
)
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")

View File

@@ -21,9 +21,9 @@ def mock_faiss_index():
def faiss_instance(mock_faiss_index):
with tempfile.TemporaryDirectory() as temp_dir:
# Mock the faiss index creation
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index):
with patch("faiss.IndexFlatL2", return_value=mock_faiss_index):
# Mock the faiss.write_index function
with patch('faiss.write_index'):
with patch("faiss.write_index"):
# Create a FAISS instance with a temporary directory
faiss_store = FAISS(
collection_name="test_collection",
@@ -37,14 +37,14 @@ def faiss_instance(mock_faiss_index):
def test_create_col(faiss_instance, mock_faiss_index):
# Test creating a collection with euclidean distance
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
with patch('faiss.write_index'):
with patch("faiss.IndexFlatL2", return_value=mock_faiss_index) as mock_index_flat_l2:
with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection")
mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)
# Test creating a collection with inner product distance
with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
with patch('faiss.write_index'):
with patch("faiss.IndexFlatIP", return_value=mock_faiss_index) as mock_index_flat_ip:
with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection", distance="inner_product")
mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)
@@ -54,21 +54,21 @@ def test_insert(faiss_instance, mock_faiss_index):
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{"name": "vector1"}, {"name": "vector2"}]
ids = ["id1", "id2"]
# Mock the numpy array conversion
with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
# Mock index.add
mock_faiss_index.add.return_value = None
# Call insert
faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify numpy.array was called
mock_np_array.assert_called_once_with(vectors, dtype=np.float32)
# Verify index.add was called
mock_faiss_index.add.assert_called_once()
# Verify docstore and index_to_id were updated
assert faiss_instance.docstore["id1"] == {"name": "vector1"}
assert faiss_instance.docstore["id2"] == {"name": "vector2"}
@@ -79,39 +79,36 @@ def test_insert(faiss_instance, mock_faiss_index):
def test_search(faiss_instance, mock_faiss_index):
# Prepare test data
query_vector = [0.1, 0.2, 0.3]
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# First, create the mock for the search return values
search_scores = np.array([[0.9, 0.8]])
search_indices = np.array([[0, 1]])
mock_faiss_index.search.return_value = (search_scores, search_indices)
# Then patch numpy.array only for the query vector conversion
with patch('numpy.array') as mock_np_array:
with patch("numpy.array") as mock_np_array:
mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
# Then patch _parse_output to return the expected results
expected_results = [
OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
OutputData(id="id2", score=0.8, payload={"name": "vector2"})
OutputData(id="id2", score=0.8, payload={"name": "vector2"}),
]
with patch.object(faiss_instance, '_parse_output', return_value=expected_results):
with patch.object(faiss_instance, "_parse_output", return_value=expected_results):
# Call search
results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)
# Verify numpy.array was called (but we don't check exact call arguments since it's complex)
assert mock_np_array.called
# Verify index.search was called
mock_faiss_index.search.assert_called_once()
# Verify results
assert len(results) == 2
assert results[0].id == "id1"
@@ -125,47 +122,41 @@ def test_search(faiss_instance, mock_faiss_index):
def test_search_with_filters(faiss_instance, mock_faiss_index):
# Prepare test data
query_vector = [0.1, 0.2, 0.3]
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1", "category": "A"},
"id2": {"name": "vector2", "category": "B"}
}
faiss_instance.docstore = {"id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# First set up the search return values
search_scores = np.array([[0.9, 0.8]])
search_indices = np.array([[0, 1]])
mock_faiss_index.search.return_value = (search_scores, search_indices)
# Patch numpy.array for query vector conversion
with patch('numpy.array') as mock_np_array:
with patch("numpy.array") as mock_np_array:
mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
# Directly mock the _parse_output method to return our expected values
# We're simulating that _parse_output filters to just the first result
all_results = [
OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}),
OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"})
OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}),
]
# Replace the _apply_filters method to handle our test case
with patch.object(faiss_instance, '_parse_output', return_value=all_results):
with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"):
with patch.object(faiss_instance, "_parse_output", return_value=all_results):
with patch.object(faiss_instance, "_apply_filters", side_effect=lambda p, f: p.get("category") == "A"):
# Call search with filters
results = faiss_instance.search(
query="test query",
vectors=query_vector,
limit=2,
filters={"category": "A"}
query="test query", vectors=query_vector, limit=2, filters={"category": "A"}
)
# Verify numpy.array was called
assert mock_np_array.called
# Verify index.search was called
mock_faiss_index.search.assert_called_once()
# Verify filtered results - since we've mocked everything,
# we should get just the result we want
assert len(results) == 1
@@ -176,15 +167,12 @@ def test_search_with_filters(faiss_instance, mock_faiss_index):
def test_delete(faiss_instance):
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# Call delete
faiss_instance.delete(vector_id="id1")
# Verify the vector was removed from docstore and index_to_id
assert "id1" not in faiss_instance.docstore
assert 0 not in faiss_instance.index_to_id
@@ -194,23 +182,20 @@ def test_delete(faiss_instance):
def test_update(faiss_instance, mock_faiss_index):
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# Test updating payload only
faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"})
assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"}
# Test updating vector
# This requires mocking the delete and insert methods
with patch.object(faiss_instance, 'delete') as mock_delete:
with patch.object(faiss_instance, 'insert') as mock_insert:
with patch.object(faiss_instance, "delete") as mock_delete:
with patch.object(faiss_instance, "insert") as mock_insert:
new_vector = [0.7, 0.8, 0.9]
faiss_instance.update(vector_id="id2", vector=new_vector)
# Verify delete and insert were called
# Match the actual call signature (positional arg instead of keyword)
mock_delete.assert_called_once_with("id2")
@@ -219,17 +204,14 @@ def test_update(faiss_instance, mock_faiss_index):
def test_get(faiss_instance):
# Setup the docstore
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
# Test getting an existing vector
result = faiss_instance.get(vector_id="id1")
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
assert result.score is None
# Test getting a non-existent vector
result = faiss_instance.get(vector_id="id3")
assert result is None
@@ -240,18 +222,18 @@ def test_list(faiss_instance):
faiss_instance.docstore = {
"id1": {"name": "vector1", "category": "A"},
"id2": {"name": "vector2", "category": "B"},
"id3": {"name": "vector3", "category": "A"}
"id3": {"name": "vector3", "category": "A"},
}
# Test listing all vectors
results = faiss_instance.list()
# Fix the expected result - the list method returns a list of lists
assert len(results[0]) == 3
# Test listing with a limit
results = faiss_instance.list(limit=2)
assert len(results[0]) == 2
# Test listing with filters
results = faiss_instance.list(filters={"category": "A"})
assert len(results[0]) == 2
@@ -263,10 +245,10 @@ def test_col_info(faiss_instance, mock_faiss_index):
# Mock index attributes
mock_faiss_index.ntotal = 5
mock_faiss_index.d = 128
# Get collection info
info = faiss_instance.col_info()
# Verify the returned info
assert info["name"] == "test_collection"
assert info["count"] == 5
@@ -276,14 +258,14 @@ def test_col_info(faiss_instance, mock_faiss_index):
def test_delete_col(faiss_instance):
# Mock the os.remove function
with patch('os.remove') as mock_remove:
with patch('os.path.exists', return_value=True):
with patch("os.remove") as mock_remove:
with patch("os.path.exists", return_value=True):
# Call delete_col
faiss_instance.delete_col()
# Verify os.remove was called twice (for index and docstore files)
assert mock_remove.call_count == 2
# Verify the internal state was reset
assert faiss_instance.index is None
assert faiss_instance.docstore == {}
@@ -293,17 +275,17 @@ def test_delete_col(faiss_instance):
def test_normalize_L2(faiss_instance, mock_faiss_index):
# Setup a FAISS instance with normalize_L2=True
faiss_instance.normalize_L2 = True
# Prepare test data
vectors = [[0.1, 0.2, 0.3]]
# Mock numpy array conversion
# Mock numpy array conversion
with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)):
with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)):
# Mock faiss.normalize_L2
with patch('faiss.normalize_L2') as mock_normalize:
with patch("faiss.normalize_L2") as mock_normalize:
# Call insert
faiss_instance.insert(vectors=vectors, ids=["id1"])
# Verify faiss.normalize_L2 was called
mock_normalize.assert_called_once()

View File

@@ -11,11 +11,13 @@ def mock_langchain_client():
with patch("langchain_community.vectorstores.VectorStore") as mock_client:
yield mock_client
@pytest.fixture
def langchain_instance(mock_langchain_client):
mock_client = Mock(spec=VectorStore)
return Langchain(client=mock_client, collection_name="test_collection")
def test_insert_vectors(langchain_instance):
# Test data
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -25,48 +27,31 @@ def test_insert_vectors(langchain_instance):
# Test with add_embeddings method
langchain_instance.client.add_embeddings = Mock()
langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
langchain_instance.client.add_embeddings.assert_called_once_with(
embeddings=vectors,
metadatas=payloads,
ids=ids
)
langchain_instance.client.add_embeddings.assert_called_once_with(embeddings=vectors, metadatas=payloads, ids=ids)
# Test with add_texts method
delattr(langchain_instance.client, "add_embeddings") # Remove attribute completely
langchain_instance.client.add_texts = Mock()
langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
langchain_instance.client.add_texts.assert_called_once_with(
texts=["text1", "text2"],
metadatas=payloads,
ids=ids
)
langchain_instance.client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=payloads, ids=ids)
# Test with empty payloads
langchain_instance.client.add_texts.reset_mock()
langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
langchain_instance.client.add_texts.assert_called_once_with(
texts=["", ""],
metadatas=None,
ids=ids
)
langchain_instance.client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=ids)
def test_search_vectors(langchain_instance):
# Mock search results
mock_docs = [
Mock(metadata={"name": "vector1"}, id="id1"),
Mock(metadata={"name": "vector2"}, id="id2")
]
mock_docs = [Mock(metadata={"name": "vector1"}, id="id1"), Mock(metadata={"name": "vector2"}, id="id2")]
langchain_instance.client.similarity_search_by_vector.return_value = mock_docs
# Test search without filters
vectors = [[0.1, 0.2, 0.3]]
results = langchain_instance.search(query="", vectors=vectors, limit=2)
langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
embedding=vectors,
k=2
)
langchain_instance.client.similarity_search_by_vector.assert_called_once_with(embedding=vectors, k=2)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].payload == {"name": "vector1"}
@@ -76,11 +61,8 @@ def test_search_vectors(langchain_instance):
# Test search with filters
filters = {"name": "vector1"}
langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
langchain_instance.client.similarity_search_by_vector.assert_called_with(
embedding=vectors,
k=2,
filter=filters
)
langchain_instance.client.similarity_search_by_vector.assert_called_with(embedding=vectors, k=2, filter=filters)
def test_get_vector(langchain_instance):
# Mock get result
@@ -90,7 +72,7 @@ def test_get_vector(langchain_instance):
# Test get existing vector
result = langchain_instance.get("id1")
langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])
assert result is not None
assert result.id == "id1"
assert result.payload == {"name": "vector1"}

View File

@@ -8,9 +8,7 @@ import pytest
try:
from opensearchpy import AWSV4SignerAuth, OpenSearch
except ImportError:
raise ImportError(
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
) from None
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
from mem0.vector_stores.opensearch import OpenSearchDB
@@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase):
def setUpClass(cls):
dotenv.load_dotenv()
cls.original_env = {
'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'),
'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'),
'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password')
"OS_URL": os.getenv("OS_URL", "http://localhost:9200"),
"OS_USERNAME": os.getenv("OS_USERNAME", "test_user"),
"OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"),
}
os.environ['OS_URL'] = 'http://localhost'
os.environ['OS_USERNAME'] = 'test_user'
os.environ['OS_PASSWORD'] = 'test_password'
os.environ["OS_URL"] = "http://localhost"
os.environ["OS_USERNAME"] = "test_user"
os.environ["OS_PASSWORD"] = "test_password"
def setUp(self):
self.client_mock = MagicMock(spec=OpenSearch)
@@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase):
self.client_mock.delete = MagicMock()
self.client_mock.search = MagicMock()
patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock)
patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock)
self.mock_os = patcher.start()
self.addCleanup(patcher.stop)
self.os_db = OpenSearchDB(
host=os.getenv('OS_URL'),
host=os.getenv("OS_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('OS_USERNAME'),
password=os.getenv('OS_PASSWORD'),
user=os.getenv("OS_USERNAME"),
password=os.getenv("OS_PASSWORD"),
verify_certs=False,
use_ssl=False
use_ssl=False,
)
self.client_mock.reset_mock()
@@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase):
vectors = [[0.1] * 1536, [0.2] * 1536]
payloads = [{"key1": "value1"}, {"key2": "value2"}]
ids = ["id1", "id2"]
# Mock the index method
self.client_mock.index = MagicMock()
results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify index was called twice (once for each vector)
self.assertEqual(self.client_mock.index.call_count, 2)
# Check first call
first_call = self.client_mock.index.call_args_list[0]
self.assertEqual(first_call[1]["index"], "test_collection")
self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0])
self.assertEqual(first_call[1]["body"]["payload"], payloads[0])
self.assertEqual(first_call[1]["body"]["id"], ids[0])
# Check second call
second_call = self.client_mock.index.call_args_list[1]
self.assertEqual(second_call[1]["index"], "test_collection")
self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1])
self.assertEqual(second_call[1]["body"]["payload"], payloads[1])
self.assertEqual(second_call[1]["body"]["id"], ids[1])
# Check results
self.assertEqual(len(results), 2)
self.assertEqual(results[0].id, "id1")
@@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase):
self.client_mock.search.return_value = {"hits": {"hits": []}}
result = self.os_db.get("nonexistent")
self.assertIsNone(result)
def test_update(self):
vector = [0.3] * 1536
payload = {"key3": "value3"}
@@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase):
self.assertEqual(result, ["test_collection"])
def test_search(self):
mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}}]}}
mock_response = {
"hits": {
"hits": [
{
"_id": "id1",
"_score": 0.8,
"_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}},
}
]
}
}
self.client_mock.search.return_value = mock_response
vectors = [[0.1] * 1536]
results = self.os_db.search(query="", vectors=vectors, limit=5)
@@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase):
self.os_db.delete_col()
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
def test_init_with_http_auth(self):
mock_credentials = MagicMock()
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch:
OpenSearchDB(
host="localhost",
port=9200,
@@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase):
embedding_model_dims=1536,
http_auth=mock_signer,
verify_certs=True,
use_ssl=True
use_ssl=True,
)
# Verify OpenSearch was initialized with correct params
@@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase):
use_ssl=True,
verify_certs=True,
connection_class=unittest.mock.ANY,
pool_maxsize=20
)
pool_maxsize=20,
)

View File

@@ -12,6 +12,7 @@ def mock_pinecone_client():
client.list_indexes.return_value.names.return_value = []
return client
@pytest.fixture
def pinecone_db(mock_pinecone_client):
return PineconeDB(
@@ -25,13 +26,14 @@ def pinecone_db(mock_pinecone_client):
hybrid_search=False,
metric="cosine",
batch_size=100,
extra_params=None
extra_params=None,
)
def test_create_col_existing_index(mock_pinecone_client):
# Set up the mock before creating the PineconeDB object
mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]
pinecone_db = PineconeDB(
collection_name="test_index",
embedding_model_dims=128,
@@ -43,21 +45,23 @@ def test_create_col_existing_index(mock_pinecone_client):
hybrid_search=False,
metric="cosine",
batch_size=100,
extra_params=None
extra_params=None,
)
# Reset the mock to verify it wasn't called during the test
mock_pinecone_client.create_index.reset_mock()
pinecone_db.create_col(128, "cosine")
mock_pinecone_client.create_index.assert_not_called()
def test_create_col_new_index(pinecone_db, mock_pinecone_client):
mock_pinecone_client.list_indexes.return_value.names.return_value = []
pinecone_db.create_col(128, "cosine")
mock_pinecone_client.create_index.assert_called()
def test_insert_vectors(pinecone_db):
vectors = [[0.1] * 128, [0.2] * 128]
payloads = [{"name": "vector1"}, {"name": "vector2"}]
@@ -65,56 +69,61 @@ def test_insert_vectors(pinecone_db):
pinecone_db.insert(vectors, payloads, ids)
pinecone_db.index.upsert.assert_called()
def test_search_vectors(pinecone_db):
pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
results = pinecone_db.search("test query",[0.1] * 128, limit=1)
results = pinecone_db.search("test query", [0.1] * 128, limit=1)
assert len(results) == 1
assert results[0].id == "id1"
assert results[0].score == 0.9
def test_update_vector(pinecone_db):
pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})
pinecone_db.index.upsert.assert_called()
def test_get_vector_found(pinecone_db):
# Looking at the _parse_output method, it expects a Vector object
# or a list of dictionaries, not a dictionary with an 'id' field
# Create a mock Vector object
from pinecone.data.dataclasses.vector import Vector
mock_vector = Vector(
id="id1",
values=[0.1] * 128,
metadata={"name": "vector1"}
)
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
# Mock the fetch method to return the mock response object
mock_response = MagicMock()
mock_response.vectors = {"id1": mock_vector}
pinecone_db.index.fetch.return_value = mock_response
result = pinecone_db.get("id1")
assert result is not None
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
def test_delete_vector(pinecone_db):
pinecone_db.delete("id1")
pinecone_db.index.delete.assert_called_with(ids=["id1"])
def test_get_vector_not_found(pinecone_db):
pinecone_db.index.fetch.return_value.vectors = {}
result = pinecone_db.get("id1")
assert result is None
def test_list_cols(pinecone_db):
pinecone_db.list_cols()
pinecone_db.client.list_indexes.assert_called()
def test_delete_col(pinecone_db):
pinecone_db.delete_col()
pinecone_db.client.delete_index.assert_called_with("test_index")
def test_col_info(pinecone_db):
pinecone_db.col_info()
pinecone_db.client.describe_index.assert_called_with("test_index")

View File

@@ -37,7 +37,7 @@ def supabase_instance(mock_vecs_client, mock_collection):
index_method=IndexMethod.HNSW,
index_measure=IndexMeasure.COSINE,
)
# Manually set the collection attribute since we're mocking the initialization
instance.collection = mock_collection
return instance
@@ -46,14 +46,8 @@ def supabase_instance(mock_vecs_client, mock_collection):
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
supabase_instance.create_col(1536)
mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
name="test_collection",
dimension=1536
)
mock_collection.create_index.assert_called_with(
method="hnsw",
measure="cosine_distance"
)
mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536)
mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")
def test_insert_vectors(supabase_instance, mock_collection):
@@ -63,18 +57,12 @@ def test_insert_vectors(supabase_instance, mock_collection):
supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
expected_records = [
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
]
expected_records = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
mock_collection.upsert.assert_called_once_with(expected_records)
def test_search_vectors(supabase_instance, mock_collection):
mock_results = [
("id1", 0.9, {"name": "vector1"}),
("id2", 0.8, {"name": "vector2"})
]
mock_results = [("id1", 0.9, {"name": "vector1"}), ("id2", 0.8, {"name": "vector2"})]
mock_collection.query.return_value = mock_results
vectors = [[0.1, 0.2, 0.3]]
@@ -82,11 +70,7 @@ def test_search_vectors(supabase_instance, mock_collection):
results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)
mock_collection.query.assert_called_once_with(
data=vectors,
limit=2,
filters={"category": {"$eq": "test"}},
include_metadata=True,
include_value=True
data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, include_value=True
)
assert len(results) == 2
@@ -129,11 +113,8 @@ def test_get_vector(supabase_instance, mock_collection):
def test_list_vectors(supabase_instance, mock_collection):
mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
mock_fetch_results = [
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
]
mock_fetch_results = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
mock_collection.query.return_value = mock_query_results
mock_collection.fetch.return_value = mock_fetch_results
@@ -153,10 +134,7 @@ def test_col_info(supabase_instance, mock_collection):
"name": "test_collection",
"count": 100,
"dimension": 1536,
"index": {
"method": "hnsw",
"metric": "cosine_distance"
}
"index": {"method": "hnsw", "metric": "cosine_distance"},
}
@@ -168,10 +146,7 @@ def test_preprocess_filters(supabase_instance):
# Test multiple filters
multi_filter = {"category": "test", "type": "document"}
assert supabase_instance._preprocess_filters(multi_filter) == {
"$and": [
{"category": {"$eq": "test"}},
{"type": {"$eq": "document"}}
]
"$and": [{"category": {"$eq": "test"}}, {"type": {"$eq": "document"}}]
}
# Test None filters

View File

@@ -29,9 +29,7 @@ def upstash_instance(mock_index):
@pytest.fixture
def upstash_instance_with_embeddings(mock_index):
return UpstashVector(
client=mock_index.return_value, collection_name="ns", enable_embeddings=True
)
return UpstashVector(client=mock_index.return_value, collection_name="ns", enable_embeddings=True)
def test_insert_vectors(upstash_instance, mock_index):
@@ -52,12 +50,8 @@ def test_insert_vectors(upstash_instance, mock_index):
def test_search_vectors(upstash_instance, mock_index):
mock_result = [
QueryResult(
id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
),
QueryResult(
id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None
),
QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None),
QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None),
]
upstash_instance.client.query_many.return_value = [mock_result]
@@ -93,9 +87,7 @@ def test_delete_vector(upstash_instance):
upstash_instance.delete(vector_id=vector_id)
upstash_instance.client.delete.assert_called_once_with(
ids=[vector_id], namespace="ns"
)
upstash_instance.client.delete.assert_called_once_with(ids=[vector_id], namespace="ns")
def test_update_vector(upstash_instance):
@@ -115,18 +107,12 @@ def test_update_vector(upstash_instance):
def test_get_vector(upstash_instance):
mock_result = [
QueryResult(
id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
)
]
mock_result = [QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)]
upstash_instance.client.fetch.return_value = mock_result
result = upstash_instance.get(vector_id="id1")
upstash_instance.client.fetch.assert_called_once_with(
ids=["id1"], namespace="ns", include_metadata=True
)
upstash_instance.client.fetch.assert_called_once_with(ids=["id1"], namespace="ns", include_metadata=True)
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
@@ -134,15 +120,9 @@ def test_get_vector(upstash_instance):
def test_list_vectors(upstash_instance):
mock_result = [
QueryResult(
id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
),
QueryResult(
id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None
),
QueryResult(
id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None
),
QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None),
QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None),
QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None),
]
handler = MagicMock()
@@ -204,12 +184,8 @@ def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_i
def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
mock_result = [
QueryResult(
id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"
),
QueryResult(
id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"
),
QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"),
QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"),
]
upstash_instance_with_embeddings.client.query.return_value = mock_result
@@ -260,9 +236,7 @@ def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embed
ValueError,
match="When embeddings are enabled, all payloads must contain a 'data' field",
):
upstash_instance_with_embeddings.insert(
vectors=vectors, payloads=payloads, ids=ids
)
upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)
def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
@@ -316,18 +290,12 @@ def test_get_vector_not_found(upstash_instance):
result = upstash_instance.get(vector_id="nonexistent")
upstash_instance.client.fetch.assert_called_once_with(
ids=["nonexistent"], namespace="ns", include_metadata=True
)
upstash_instance.client.fetch.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True)
assert result is None
def test_search_vectors_empty_filters(upstash_instance):
mock_result = [
QueryResult(
id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
)
]
mock_result = [QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)]
upstash_instance.client.query_many.return_value = [mock_result]
vectors = [[0.1, 0.2, 0.3]]

View File

@@ -14,47 +14,50 @@ from mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine
@pytest.fixture
def mock_vertex_ai():
with patch('google.cloud.aiplatform.MatchingEngineIndex') as mock_index, \
patch('google.cloud.aiplatform.MatchingEngineIndexEndpoint') as mock_endpoint, \
patch('google.cloud.aiplatform.init') as mock_init:
with (
patch("google.cloud.aiplatform.MatchingEngineIndex") as mock_index,
patch("google.cloud.aiplatform.MatchingEngineIndexEndpoint") as mock_endpoint,
patch("google.cloud.aiplatform.init") as mock_init,
):
mock_index_instance = Mock()
mock_endpoint_instance = Mock()
yield {
'index': mock_index_instance,
'endpoint': mock_endpoint_instance,
'init': mock_init,
'index_class': mock_index,
'endpoint_class': mock_endpoint
"index": mock_index_instance,
"endpoint": mock_endpoint_instance,
"init": mock_init,
"index_class": mock_index,
"endpoint_class": mock_endpoint,
}
@pytest.fixture
def config():
return GoogleMatchingEngineConfig(
project_id='test-project',
project_number='123456789',
region='us-central1',
endpoint_id='test-endpoint',
index_id='test-index',
deployment_index_id='test-deployment',
collection_name='test-collection',
vector_search_api_endpoint='test.vertexai.goog'
project_id="test-project",
project_number="123456789",
region="us-central1",
endpoint_id="test-endpoint",
index_id="test-index",
deployment_index_id="test-deployment",
collection_name="test-collection",
vector_search_api_endpoint="test.vertexai.goog",
)
@pytest.fixture
def vector_store(config, mock_vertex_ai):
mock_vertex_ai['index_class'].return_value = mock_vertex_ai['index']
mock_vertex_ai['endpoint_class'].return_value = mock_vertex_ai['endpoint']
mock_vertex_ai["index_class"].return_value = mock_vertex_ai["index"]
mock_vertex_ai["endpoint_class"].return_value = mock_vertex_ai["endpoint"]
return GoogleMatchingEngine(**config.model_dump())
def test_initialization(vector_store, mock_vertex_ai, config):
"""Test proper initialization of GoogleMatchingEngine"""
mock_vertex_ai['init'].assert_called_once_with(
project=config.project_id,
location=config.region
)
mock_vertex_ai["init"].assert_called_once_with(project=config.project_id, location=config.region)
expected_index_path = f"projects/{config.project_number}/locations/{config.region}/indexes/{config.index_id}"
mock_vertex_ai['index_class'].assert_called_once_with(index_name=expected_index_path)
mock_vertex_ai["index_class"].assert_called_once_with(index_name=expected_index_path)
def test_insert_vectors(vector_store, mock_vertex_ai):
"""Test inserting vectors with payloads"""
@@ -64,13 +67,14 @@ def test_insert_vectors(vector_store, mock_vertex_ai):
vector_store.insert(vectors=vectors, payloads=payloads, ids=ids)
mock_vertex_ai['index'].upsert_datapoints.assert_called_once()
call_args = mock_vertex_ai['index'].upsert_datapoints.call_args[1]
assert len(call_args['datapoints']) == 1
datapoint_str = str(call_args['datapoints'][0])
mock_vertex_ai["index"].upsert_datapoints.assert_called_once()
call_args = mock_vertex_ai["index"].upsert_datapoints.call_args[1]
assert len(call_args["datapoints"]) == 1
datapoint_str = str(call_args["datapoints"][0])
assert "test-id" in datapoint_str
assert "0.1" in datapoint_str and "0.2" in datapoint_str and "0.3" in datapoint_str
def test_search_vectors(vector_store, mock_vertex_ai):
"""Test searching vectors with filters"""
vectors = [[0.1, 0.2, 0.3]]
@@ -85,7 +89,7 @@ def test_search_vectors(vector_store, mock_vertex_ai):
mock_restrict.allow_list = ["test_user"]
mock_restrict.name = "user_id"
mock_restrict.allow_tokens = ["test_user"]
mock_datapoint.restricts = [mock_restrict]
mock_neighbor = Mock()
@@ -94,16 +98,16 @@ def test_search_vectors(vector_store, mock_vertex_ai):
mock_neighbor.datapoint = mock_datapoint
mock_neighbor.restricts = [mock_restrict]
mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
mock_vertex_ai["endpoint"].find_neighbors.return_value = [[mock_neighbor]]
results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)
mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
mock_vertex_ai["endpoint"].find_neighbors.assert_called_once_with(
deployed_index_id=vector_store.deployment_index_id,
queries=[vectors],
num_neighbors=1,
filter=[Namespace("user_id", ["test_user"], [])],
return_full_datapoint=True
return_full_datapoint=True,
)
assert len(results) == 1
@@ -111,29 +115,27 @@ def test_search_vectors(vector_store, mock_vertex_ai):
assert results[0].score == 0.1
assert results[0].payload == {"user_id": "test_user"}
def test_delete(vector_store, mock_vertex_ai):
"""Test deleting vectors"""
vector_id = "test-id"
remove_mock = Mock()
with patch.object(GoogleMatchingEngine, 'delete', wraps=vector_store.delete) as delete_spy:
with patch.object(vector_store.index, 'remove_datapoints', remove_mock):
with patch.object(GoogleMatchingEngine, "delete", wraps=vector_store.delete) as delete_spy:
with patch.object(vector_store.index, "remove_datapoints", remove_mock):
vector_store.delete(ids=[vector_id])
delete_spy.assert_called_once_with(ids=[vector_id])
remove_mock.assert_called_once_with(datapoint_ids=[vector_id])
def test_error_handling(vector_store, mock_vertex_ai):
"""Test error handling during operations"""
mock_vertex_ai['index'].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
mock_vertex_ai["index"].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
with pytest.raises(Exception) as exc_info:
vector_store.insert(
vectors=[[0.1, 0.2, 0.3]],
payloads=[{"name": "test"}],
ids=["test-id"]
)
vector_store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"name": "test"}], ids=["test-id"])
assert isinstance(exc_info.value, exceptions.InvalidArgument)
assert "Invalid request" in str(exc_info.value)

View File

@@ -76,15 +76,15 @@
# self.client_mock.batch = MagicMock()
# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()
# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
# "results": [{"id": "id1"}, {"id": "id2"}]
# }
# vectors = [[0.1] * 1536, [0.2] * 1536]
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# def test_get(self):
@@ -108,7 +108,7 @@
# result = self.weaviate_db.get(vector_id=valid_uuid)
# assert result.id == valid_uuid
# expected_payload = mock_response.properties.copy()
# expected_payload["id"] = valid_uuid
@@ -131,10 +131,10 @@
# "metadata": {"distance": 0.2}
# }
# ]
# mock_response = MagicMock()
# mock_response.objects = []
# for obj in mock_objects:
# mock_obj = MagicMock()
# mock_obj.uuid = obj["uuid"]
@@ -142,16 +142,16 @@
# mock_obj.metadata = MagicMock()
# mock_obj.metadata.distance = obj["metadata"]["distance"]
# mock_response.objects.append(mock_obj)
# mock_hybrid = MagicMock()
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
# mock_hybrid.return_value = mock_response
# vectors = [[0.1] * 1536]
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
# mock_hybrid.assert_called_once()
# self.assertEqual(len(results), 1)
# self.assertEqual(results[0].id, "id1")
# self.assertEqual(results[0].score, 0.8)
@@ -163,28 +163,28 @@
# def test_list(self):
# mock_objects = []
# mock_obj1 = MagicMock()
# mock_obj1.uuid = "id1"
# mock_obj1.properties = {"key1": "value1"}
# mock_objects.append(mock_obj1)
# mock_obj2 = MagicMock()
# mock_obj2.uuid = "id2"
# mock_obj2.properties = {"key2": "value2"}
# mock_objects.append(mock_obj2)
# mock_response = MagicMock()
# mock_response.objects = mock_objects
# mock_fetch = MagicMock()
# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
# mock_fetch.return_value = mock_response
# results = self.weaviate_db.list(limit=10)
# mock_fetch.assert_called_once()
# # Verify results
# self.assertEqual(len(results), 1)
# self.assertEqual(len(results[0]), 2)