feat: where filter in vector database (#518)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config import AppConfig, ChatConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
@@ -35,3 +35,65 @@ class TestApp(unittest.TestCase):
|
||||
second_answer = app.chat("Test query 2")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
self.assertEqual(len(app.memory.chat_memory.messages), 4)
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.chat("Test chat", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_chat_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
|
||||
in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
chatConfig = ChatConfig(where={"attribute": "value"})
|
||||
answer = self.app.chat("Test chat", chatConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
|
||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||
self.assertEqual(messages_arg[0]["role"], "system")
|
||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_query_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
queryConfig = QueryConfig(where={"attribute": "value"})
|
||||
answer = self.app.query("Test query", queryConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user