feat: where filter in vector database (#518)

This commit is contained in:
sw8fbar
2023-09-04 15:49:59 -05:00
committed by GitHub
parent 202fd2d5b6
commit 3e66ddf69a
6 changed files with 156 additions and 14 deletions

View File

@@ -4,8 +4,7 @@ from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
from embedchain.config import ChatConfig, QueryConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.helper_classes.json_serializable import register_deserializable
@@ -60,12 +59,12 @@ class PersonApp(EmbedChainPersonApp, App):
"""
def query(self, input_query, config: QueryConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
return super().query(input_query, config, dry_run)
config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
return super().query(input_query, config, dry_run, where=None)
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
return super().chat(input_query, config, dry_run)
return super().chat(input_query, config, dry_run, where)
@register_deserializable

View File

@@ -38,6 +38,7 @@ class ChatConfig(QueryConfig):
stream: bool = False,
deployment_name=None,
system_prompt: Optional[str] = None,
where=None,
):
"""
Initializes the ChatConfig instance.
@@ -57,6 +58,7 @@ class ChatConfig(QueryConfig):
:param stream: Optional. Control if response is streamed back to the user
:param deployment_name: t.b.a.
:param system_prompt: Optional. System prompt string.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:raises ValueError: If the template is not valid as template should contain
$context and $query and $history
"""
@@ -77,6 +79,7 @@ class ChatConfig(QueryConfig):
stream=stream,
deployment_name=deployment_name,
system_prompt=system_prompt,
where=where,
)
def set_history(self, history):

View File

@@ -67,6 +67,7 @@ class QueryConfig(BaseConfig):
stream: bool = False,
deployment_name=None,
system_prompt: Optional[str] = None,
where=None,
):
"""
Initializes the QueryConfig instance.
@@ -87,6 +88,7 @@ class QueryConfig(BaseConfig):
:param stream: Optional. Control if response is streamed back to user
:param deployment_name: t.b.a.
:param system_prompt: Optional. System prompt string.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:raises ValueError: If the template is not valid as template should
contain $context and $query (and optionally $history).
"""
@@ -127,6 +129,7 @@ class QueryConfig(BaseConfig):
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
self.stream = stream
self.where = where
def validate_template(self, template: Template):
"""

View File

@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
"""
raise NotImplementedError
def retrieve_from_database(self, input_query, config: QueryConfig):
def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query
:param input_query: The query to use.
:param config: The query configuration.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The content of the document that matched your query.
"""
where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
if where is not None:
where = where
elif config is not None and config.where is not None:
where = config.where
else:
where = {}
if self.config.id is not None:
where.update({"app_id": self.config.id})
contents = self.db.query(
input_query=input_query,
n_results=config.number_documents,
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def query(self, input_query, config: QueryConfig = None, dry_run=False):
def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config)
contexts = self.retrieve_from_database(input_query, config, where)
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
yield chunk
logging.info(f"Answer: {streamed_answer}")
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config)
contexts = self.retrieve_from_database(input_query, config, where)
chat_history = self.memory.load_memory_variables({})["history"]

View File

@@ -1,9 +1,9 @@
import os
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config import AppConfig, ChatConfig
class TestApp(unittest.TestCase):
@@ -35,3 +35,65 @@ class TestApp(unittest.TestCase):
second_answer = app.chat("Test query 2")
self.assertEqual(second_answer, "Test answer")
self.assertEqual(len(app.memory.chat_memory.messages), 4)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_params(self):
"""
This test checks the functionality of the 'chat' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'chat' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'chat' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.chat("Test chat", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_chat_config(self):
"""
This test checks the functionality of the 'chat' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'chat' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
chatConfig = ChatConfig(where={"attribute": "value"})
answer = self.app.chat("Test chat", chatConfig)
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
mock_answer.assert_called_once()

View File

@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
messages_arg = mock_create.call_args.kwargs["messages"]
self.assertEqual(messages_arg[0]["role"], "system")
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
queryConfig = QueryConfig(where={"attribute": "value"})
answer = self.app.query("Test query", queryConfig)
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
mock_answer.assert_called_once()