From 3e66ddf69ab0e4534df4fd4303e8039d74865009 Mon Sep 17 00:00:00 2001
From: sw8fbar
Date: Mon, 4 Sep 2023 15:49:59 -0500
Subject: [PATCH] feat: where filter in vector database (#518)

---
 embedchain/apps/PersonApp.py     | 11 +++---
 embedchain/config/ChatConfig.py  |  3 ++
 embedchain/config/QueryConfig.py |  3 ++
 embedchain/embedchain.py         | 25 +++++++++---
 tests/embedchain/test_chat.py    | 66 +++++++++++++++++++++++++++++++-
 tests/embedchain/test_query.py   | 62 ++++++++++++++++++++++++++++++
 6 files changed, 156 insertions(+), 14 deletions(-)

diff --git a/embedchain/apps/PersonApp.py b/embedchain/apps/PersonApp.py
index d38719d6..a804fa6b 100644
--- a/embedchain/apps/PersonApp.py
+++ b/embedchain/apps/PersonApp.py
@@ -4,8 +4,7 @@ from embedchain.apps.App import App
 from embedchain.apps.OpenSourceApp import OpenSourceApp
 from embedchain.config import ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
-from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
-                                           DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
 from embedchain.helper_classes.json_serializable import register_deserializable
 
 
@@ -60,12 +59,12 @@ class PersonApp(EmbedChainPersonApp, App):
     """
 
     def query(self, input_query, config: QueryConfig = None, dry_run=False):
-        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
-        return super().query(input_query, config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
+        return super().query(input_query, config, dry_run, where=None)
 
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
         config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
-        return super().chat(input_query, config, dry_run)
+        return super().chat(input_query, config, dry_run, where)
 
 
 @register_deserializable
diff --git a/embedchain/config/ChatConfig.py b/embedchain/config/ChatConfig.py
index 0c403869..c6bbcc48 100644
--- a/embedchain/config/ChatConfig.py
+++ b/embedchain/config/ChatConfig.py
@@ -38,6 +38,7 @@ class ChatConfig(QueryConfig):
         stream: bool = False,
         deployment_name=None,
         system_prompt: Optional[str] = None,
+        where=None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -57,6 +58,7 @@ class ChatConfig(QueryConfig):
         :param stream: Optional. Control if response is streamed back to the user
         :param deployment_name: t.b.a.
         :param system_prompt: Optional. System prompt string.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :raises ValueError: If the template is not valid as template should
         contain $context and $query and $history
         """
@@ -77,6 +79,7 @@ class ChatConfig(QueryConfig):
             stream=stream,
             deployment_name=deployment_name,
             system_prompt=system_prompt,
+            where=where,
         )
 
     def set_history(self, history):
diff --git a/embedchain/config/QueryConfig.py b/embedchain/config/QueryConfig.py
index a8c703a4..b4c29882 100644
--- a/embedchain/config/QueryConfig.py
+++ b/embedchain/config/QueryConfig.py
@@ -67,6 +67,7 @@ class QueryConfig(BaseConfig):
         stream: bool = False,
         deployment_name=None,
         system_prompt: Optional[str] = None,
+        where=None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -87,6 +88,7 @@ class QueryConfig(BaseConfig):
         :param stream: Optional. Control if response is streamed back to user
         :param deployment_name: t.b.a.
         :param system_prompt: Optional. System prompt string.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history).
         """
@@ -127,6 +129,7 @@ class QueryConfig(BaseConfig):
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
+        self.where = where
 
     def validate_template(self, template: Template):
         """
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index d379863b..9f30b9c5 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
         """
         raise NotImplementedError
 
-    def retrieve_from_database(self, input_query, config: QueryConfig):
+    def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
 
         :param input_query: The query to use.
         :param config: The query configuration.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The content of the document that matched your query.
         """
-        where = {"app_id": self.config.id} if self.config.id is not None else {}  # optional filter
+
+        if where is not None:
+            where = where
+        elif config is not None and config.where is not None:
+            where = config.where
+        else:
+            where = {}
+
+        if self.config.id is not None:
+            where.update({"app_id": self.config.id})
+
         contents = self.db.query(
             input_query=input_query,
             n_results=config.number_documents,
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def query(self, input_query, config: QueryConfig = None, dry_run=False):
+    def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
         by the vector database's doc retrieval.
         The only thing the dry run does not consider is the cut-off due to
         the `max_tokens` parameter.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
         if config is None:
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
         k = {}
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config)
+        contexts = self.retrieve_from_database(input_query, config, where)
         prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
 
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
             yield chunk
         logging.info(f"Answer: {streamed_answer}")
 
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
         by the vector database's doc retrieval.
         The only thing the dry run does not consider is the cut-off due to
         the `max_tokens` parameter.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
""" if config is None: @@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable): k = {} if self.online: k["web_search_result"] = self.access_search_and_get_results(input_query) - contexts = self.retrieve_from_database(input_query, config) + contexts = self.retrieve_from_database(input_query, config, where) chat_history = self.memory.load_memory_variables({})["history"] diff --git a/tests/embedchain/test_chat.py b/tests/embedchain/test_chat.py index fe49a050..874f698f 100644 --- a/tests/embedchain/test_chat.py +++ b/tests/embedchain/test_chat.py @@ -1,9 +1,9 @@ import os import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch from embedchain import App -from embedchain.config import AppConfig +from embedchain.config import AppConfig, ChatConfig class TestApp(unittest.TestCase): @@ -35,3 +35,65 @@ class TestApp(unittest.TestCase): second_answer = app.chat("Test query 2") self.assertEqual(second_answer, "Test answer") self.assertEqual(len(app.memory.chat_memory.messages), 4) + + @patch("chromadb.api.models.Collection.Collection.add", MagicMock) + def test_chat_with_where_in_params(self): + """ + This test checks the functionality of the 'chat' method in the App class. + It simulates a scenario where the 'retrieve_from_database' method returns a context list based on + a where filter and 'get_llm_model_answer' returns an expected answer string. + + The 'chat' method is expected to call 'retrieve_from_database' with the where filter and + 'get_llm_model_answer' methods appropriately and return the right answer. + + Key assumptions tested: + - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of + QueryConfig. + - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test. + - 'chat' method returns the value it received from 'get_llm_model_answer'. + + The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and + 'get_llm_model_answer' methods. + """ + with patch.object(self.app, "retrieve_from_database") as mock_retrieve: + mock_retrieve.return_value = ["Test context"] + with patch.object(self.app, "get_llm_model_answer") as mock_answer: + mock_answer.return_value = "Test answer" + answer = self.app.chat("Test chat", where={"attribute": "value"}) + + self.assertEqual(answer, "Test answer") + self.assertEqual(mock_retrieve.call_args[0][0], "Test chat") + self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"}) + mock_answer.assert_called_once() + + @patch("chromadb.api.models.Collection.Collection.add", MagicMock) + def test_chat_with_where_in_chat_config(self): + """ + This test checks the functionality of the 'chat' method in the App class. + It simulates a scenario where the 'retrieve_from_database' method returns a context list based on + a where filter and 'get_llm_model_answer' returns an expected answer string. + + The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified + in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer. + + Key assumptions tested: + - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of + QueryConfig. + - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test. + - 'chat' method returns the value it received from 'get_llm_model_answer'. 
+
+        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                chatConfig = ChatConfig(where={"attribute": "value"})
+                answer = self.app.chat("Test chat", chatConfig)
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
+        self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
+        self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
+        mock_answer.assert_called_once()
diff --git a/tests/embedchain/test_query.py b/tests/embedchain/test_query.py
index 46be1427..521686f6 100644
--- a/tests/embedchain/test_query.py
+++ b/tests/embedchain/test_query.py
@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
         messages_arg = mock_create.call_args.kwargs["messages"]
         self.assertEqual(messages_arg[0]["role"], "system")
         self.assertEqual(messages_arg[0]["content"], "Test system prompt")
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query_with_where_in_params(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' with the where filter, call
+        'get_llm_model_answer' appropriately, and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query", an instance of
+        QueryConfig, and the where filter dictionary.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                answer = self.app.query("Test query", where={"attribute": "value"})
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
+        mock_answer.assert_called_once()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query_with_where_in_query_config(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' with the where filter from the
+        QueryConfig, call 'get_llm_model_answer' appropriately, and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+        QueryConfig carrying the where filter.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                queryConfig = QueryConfig(where={"attribute": "value"})
+                answer = self.app.query("Test query", queryConfig)
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
+        self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
+        mock_answer.assert_called_once()
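
Usage sketch: a minimal illustration of the new filter, assuming the same App and config wiring
the tests above rely on (including OPENAI_API_KEY being set, as the tests do via os.environ).
The "attribute" metadata key is illustrative; retrieve_from_database still merges the app's
app_id into whatever filter is supplied.

    from embedchain import App
    from embedchain.config import ChatConfig, QueryConfig

    app = App()  # assumes OPENAI_API_KEY is available in the environment

    # Pass the filter per call, as exercised by test_query_with_where_in_params.
    answer = app.query("Test query", where={"attribute": "value"})

    # Or carry it on the config object, as exercised by test_query_with_where_in_query_config.
    answer = app.query("Test query", QueryConfig(where={"attribute": "value"}))

    # chat() accepts the same two forms.
    answer = app.chat("Test chat", ChatConfig(where={"attribute": "value"}))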