feat: where filter in vector database (#518)
This commit is contained in:
@@ -4,8 +4,7 @@ from embedchain.apps.App import App
|
||||
from embedchain.apps.OpenSourceApp import OpenSourceApp
|
||||
from embedchain.config import ChatConfig, QueryConfig
|
||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY)
|
||||
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -60,12 +59,12 @@ class PersonApp(EmbedChainPersonApp, App):
|
||||
"""
|
||||
|
||||
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
||||
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
|
||||
return super().query(input_query, config, dry_run)
|
||||
config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
|
||||
return super().query(input_query, config, dry_run, where=None)
|
||||
|
||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
|
||||
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
|
||||
return super().chat(input_query, config, dry_run)
|
||||
return super().chat(input_query, config, dry_run, where)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -38,6 +38,7 @@ class ChatConfig(QueryConfig):
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
system_prompt: Optional[str] = None,
|
||||
where=None,
|
||||
):
|
||||
"""
|
||||
Initializes the ChatConfig instance.
|
||||
@@ -57,6 +58,7 @@ class ChatConfig(QueryConfig):
|
||||
:param stream: Optional. Control if response is streamed back to the user
|
||||
:param deployment_name: t.b.a.
|
||||
:param system_prompt: Optional. System prompt string.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:raises ValueError: If the template is not valid as template should contain
|
||||
$context and $query and $history
|
||||
"""
|
||||
@@ -77,6 +79,7 @@ class ChatConfig(QueryConfig):
|
||||
stream=stream,
|
||||
deployment_name=deployment_name,
|
||||
system_prompt=system_prompt,
|
||||
where=where,
|
||||
)
|
||||
|
||||
def set_history(self, history):
|
||||
|
||||
@@ -67,6 +67,7 @@ class QueryConfig(BaseConfig):
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
system_prompt: Optional[str] = None,
|
||||
where=None,
|
||||
):
|
||||
"""
|
||||
Initializes the QueryConfig instance.
|
||||
@@ -87,6 +88,7 @@ class QueryConfig(BaseConfig):
|
||||
:param stream: Optional. Control if response is streamed back to user
|
||||
:param deployment_name: t.b.a.
|
||||
:param system_prompt: Optional. System prompt string.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:raises ValueError: If the template is not valid as template should
|
||||
contain $context and $query (and optionally $history).
|
||||
"""
|
||||
@@ -127,6 +129,7 @@ class QueryConfig(BaseConfig):
|
||||
if not isinstance(stream, bool):
|
||||
raise ValueError("`stream` should be bool")
|
||||
self.stream = stream
|
||||
self.where = where
|
||||
|
||||
def validate_template(self, template: Template):
|
||||
"""
|
||||
|
||||
@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def retrieve_from_database(self, input_query, config: QueryConfig):
|
||||
def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query
|
||||
|
||||
:param input_query: The query to use.
|
||||
:param config: The query configuration.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The content of the document that matched your query.
|
||||
"""
|
||||
where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
|
||||
|
||||
if where is not None:
|
||||
where = where
|
||||
elif config is not None and config.where is not None:
|
||||
where = config.where
|
||||
else:
|
||||
where = {}
|
||||
|
||||
if self.config.id is not None:
|
||||
where.update({"app_id": self.config.id})
|
||||
|
||||
contents = self.db.query(
|
||||
input_query=input_query,
|
||||
n_results=config.number_documents,
|
||||
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
|
||||
logging.info(f"Access search to get answers for {input_query}")
|
||||
return search.run(input_query)
|
||||
|
||||
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
||||
def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
|
||||
by the vector database's doc retrieval.
|
||||
The only thing the dry run does not consider is the cut-off due to
|
||||
the `max_tokens` parameter.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The answer to the query.
|
||||
"""
|
||||
if config is None:
|
||||
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
|
||||
k = {}
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
contexts = self.retrieve_from_database(input_query, config, where)
|
||||
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
|
||||
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
|
||||
yield chunk
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
|
||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
|
||||
"""
|
||||
Queries the vector database on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
|
||||
by the vector database's doc retrieval.
|
||||
The only thing the dry run does not consider is the cut-off due to
|
||||
the `max_tokens` parameter.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The answer to the query.
|
||||
"""
|
||||
if config is None:
|
||||
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
|
||||
k = {}
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
contexts = self.retrieve_from_database(input_query, config, where)
|
||||
|
||||
chat_history = self.memory.load_memory_variables({})["history"]
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config import AppConfig, ChatConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
@@ -35,3 +35,65 @@ class TestApp(unittest.TestCase):
|
||||
second_answer = app.chat("Test query 2")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
self.assertEqual(len(app.memory.chat_memory.messages), 4)
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.chat("Test chat", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_chat_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
|
||||
in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
chatConfig = ChatConfig(where={"attribute": "value"})
|
||||
answer = self.app.chat("Test chat", chatConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
|
||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||
self.assertEqual(messages_arg[0]["role"], "system")
|
||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_query_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
queryConfig = QueryConfig(where={"attribute": "value"})
|
||||
answer = self.app.query("Test query", queryConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user