feat: where filter in vector database (#518)
This commit is contained in:
@@ -4,8 +4,7 @@ from embedchain.apps.App import App
|
|||||||
from embedchain.apps.OpenSourceApp import OpenSourceApp
|
from embedchain.apps.OpenSourceApp import OpenSourceApp
|
||||||
from embedchain.config import ChatConfig, QueryConfig
|
from embedchain.config import ChatConfig, QueryConfig
|
||||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||||
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
|
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
|
||||||
DEFAULT_PROMPT_WITH_HISTORY)
|
|
||||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -60,12 +59,12 @@ class PersonApp(EmbedChainPersonApp, App):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
||||||
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
|
config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
|
||||||
return super().query(input_query, config, dry_run)
|
return super().query(input_query, config, dry_run, where=None)
|
||||||
|
|
||||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
|
||||||
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
|
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
|
||||||
return super().chat(input_query, config, dry_run)
|
return super().chat(input_query, config, dry_run, where)
|
||||||
|
|
||||||
|
|
||||||
@register_deserializable
|
@register_deserializable
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class ChatConfig(QueryConfig):
|
|||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
deployment_name=None,
|
deployment_name=None,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
|
where=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the ChatConfig instance.
|
Initializes the ChatConfig instance.
|
||||||
@@ -57,6 +58,7 @@ class ChatConfig(QueryConfig):
|
|||||||
:param stream: Optional. Control if response is streamed back to the user
|
:param stream: Optional. Control if response is streamed back to the user
|
||||||
:param deployment_name: t.b.a.
|
:param deployment_name: t.b.a.
|
||||||
:param system_prompt: Optional. System prompt string.
|
:param system_prompt: Optional. System prompt string.
|
||||||
|
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||||
:raises ValueError: If the template is not valid as template should contain
|
:raises ValueError: If the template is not valid as template should contain
|
||||||
$context and $query and $history
|
$context and $query and $history
|
||||||
"""
|
"""
|
||||||
@@ -77,6 +79,7 @@ class ChatConfig(QueryConfig):
|
|||||||
stream=stream,
|
stream=stream,
|
||||||
deployment_name=deployment_name,
|
deployment_name=deployment_name,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
where=where,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_history(self, history):
|
def set_history(self, history):
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class QueryConfig(BaseConfig):
|
|||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
deployment_name=None,
|
deployment_name=None,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
|
where=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the QueryConfig instance.
|
Initializes the QueryConfig instance.
|
||||||
@@ -87,6 +88,7 @@ class QueryConfig(BaseConfig):
|
|||||||
:param stream: Optional. Control if response is streamed back to user
|
:param stream: Optional. Control if response is streamed back to user
|
||||||
:param deployment_name: t.b.a.
|
:param deployment_name: t.b.a.
|
||||||
:param system_prompt: Optional. System prompt string.
|
:param system_prompt: Optional. System prompt string.
|
||||||
|
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||||
:raises ValueError: If the template is not valid as template should
|
:raises ValueError: If the template is not valid as template should
|
||||||
contain $context and $query (and optionally $history).
|
contain $context and $query (and optionally $history).
|
||||||
"""
|
"""
|
||||||
@@ -127,6 +129,7 @@ class QueryConfig(BaseConfig):
|
|||||||
if not isinstance(stream, bool):
|
if not isinstance(stream, bool):
|
||||||
raise ValueError("`stream` should be bool")
|
raise ValueError("`stream` should be bool")
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
self.where = where
|
||||||
|
|
||||||
def validate_template(self, template: Template):
|
def validate_template(self, template: Template):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def retrieve_from_database(self, input_query, config: QueryConfig):
|
def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
|
||||||
"""
|
"""
|
||||||
Queries the vector database based on the given input query.
|
Queries the vector database based on the given input query.
|
||||||
Gets relevant doc based on the query
|
Gets relevant doc based on the query
|
||||||
|
|
||||||
:param input_query: The query to use.
|
:param input_query: The query to use.
|
||||||
:param config: The query configuration.
|
:param config: The query configuration.
|
||||||
|
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||||
:return: The content of the document that matched your query.
|
:return: The content of the document that matched your query.
|
||||||
"""
|
"""
|
||||||
where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
|
|
||||||
|
if where is not None:
|
||||||
|
where = where
|
||||||
|
elif config is not None and config.where is not None:
|
||||||
|
where = config.where
|
||||||
|
else:
|
||||||
|
where = {}
|
||||||
|
|
||||||
|
if self.config.id is not None:
|
||||||
|
where.update({"app_id": self.config.id})
|
||||||
|
|
||||||
contents = self.db.query(
|
contents = self.db.query(
|
||||||
input_query=input_query,
|
input_query=input_query,
|
||||||
n_results=config.number_documents,
|
n_results=config.number_documents,
|
||||||
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
logging.info(f"Access search to get answers for {input_query}")
|
logging.info(f"Access search to get answers for {input_query}")
|
||||||
return search.run(input_query)
|
return search.run(input_query)
|
||||||
|
|
||||||
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
|
||||||
"""
|
"""
|
||||||
Queries the vector database based on the given input query.
|
Queries the vector database based on the given input query.
|
||||||
Gets relevant doc based on the query and then passes it to an
|
Gets relevant doc based on the query and then passes it to an
|
||||||
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
by the vector database's doc retrieval.
|
by the vector database's doc retrieval.
|
||||||
The only thing the dry run does not consider is the cut-off due to
|
The only thing the dry run does not consider is the cut-off due to
|
||||||
the `max_tokens` parameter.
|
the `max_tokens` parameter.
|
||||||
|
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||||
:return: The answer to the query.
|
:return: The answer to the query.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
k = {}
|
k = {}
|
||||||
if self.online:
|
if self.online:
|
||||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||||
contexts = self.retrieve_from_database(input_query, config)
|
contexts = self.retrieve_from_database(input_query, config, where)
|
||||||
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
||||||
logging.info(f"Prompt: {prompt}")
|
logging.info(f"Prompt: {prompt}")
|
||||||
|
|
||||||
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
yield chunk
|
yield chunk
|
||||||
logging.info(f"Answer: {streamed_answer}")
|
logging.info(f"Answer: {streamed_answer}")
|
||||||
|
|
||||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
|
||||||
"""
|
"""
|
||||||
Queries the vector database on the given input query.
|
Queries the vector database on the given input query.
|
||||||
Gets relevant doc based on the query and then passes it to an
|
Gets relevant doc based on the query and then passes it to an
|
||||||
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
by the vector database's doc retrieval.
|
by the vector database's doc retrieval.
|
||||||
The only thing the dry run does not consider is the cut-off due to
|
The only thing the dry run does not consider is the cut-off due to
|
||||||
the `max_tokens` parameter.
|
the `max_tokens` parameter.
|
||||||
|
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||||
:return: The answer to the query.
|
:return: The answer to the query.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
k = {}
|
k = {}
|
||||||
if self.online:
|
if self.online:
|
||||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||||
contexts = self.retrieve_from_database(input_query, config)
|
contexts = self.retrieve_from_database(input_query, config, where)
|
||||||
|
|
||||||
chat_history = self.memory.load_memory_variables({})["history"]
|
chat_history = self.memory.load_memory_variables({})["history"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig, ChatConfig
|
||||||
|
|
||||||
|
|
||||||
class TestApp(unittest.TestCase):
|
class TestApp(unittest.TestCase):
|
||||||
@@ -35,3 +35,65 @@ class TestApp(unittest.TestCase):
|
|||||||
second_answer = app.chat("Test query 2")
|
second_answer = app.chat("Test query 2")
|
||||||
self.assertEqual(second_answer, "Test answer")
|
self.assertEqual(second_answer, "Test answer")
|
||||||
self.assertEqual(len(app.memory.chat_memory.messages), 4)
|
self.assertEqual(len(app.memory.chat_memory.messages), 4)
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
|
def test_chat_with_where_in_params(self):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'chat' method in the App class.
|
||||||
|
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||||
|
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||||
|
|
||||||
|
The 'chat' method is expected to call 'retrieve_from_database' with the where filter and
|
||||||
|
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||||
|
|
||||||
|
Key assumptions tested:
|
||||||
|
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||||
|
QueryConfig.
|
||||||
|
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||||
|
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||||
|
|
||||||
|
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||||
|
'get_llm_model_answer' methods.
|
||||||
|
"""
|
||||||
|
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||||
|
mock_retrieve.return_value = ["Test context"]
|
||||||
|
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||||
|
mock_answer.return_value = "Test answer"
|
||||||
|
answer = self.app.chat("Test chat", where={"attribute": "value"})
|
||||||
|
|
||||||
|
self.assertEqual(answer, "Test answer")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||||
|
mock_answer.assert_called_once()
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
|
def test_chat_with_where_in_chat_config(self):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'chat' method in the App class.
|
||||||
|
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||||
|
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||||
|
|
||||||
|
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
|
||||||
|
in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
||||||
|
|
||||||
|
Key assumptions tested:
|
||||||
|
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||||
|
QueryConfig.
|
||||||
|
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||||
|
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||||
|
|
||||||
|
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||||
|
'get_llm_model_answer' methods.
|
||||||
|
"""
|
||||||
|
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||||
|
mock_retrieve.return_value = ["Test context"]
|
||||||
|
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||||
|
mock_answer.return_value = "Test answer"
|
||||||
|
chatConfig = ChatConfig(where={"attribute": "value"})
|
||||||
|
answer = self.app.chat("Test chat", chatConfig)
|
||||||
|
|
||||||
|
self.assertEqual(answer, "Test answer")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||||
|
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
|
||||||
|
mock_answer.assert_called_once()
|
||||||
|
|||||||
@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
|
|||||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||||
self.assertEqual(messages_arg[0]["role"], "system")
|
self.assertEqual(messages_arg[0]["role"], "system")
|
||||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
|
def test_query_with_where_in_params(self):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'query' method in the App class.
|
||||||
|
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||||
|
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||||
|
|
||||||
|
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||||
|
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||||
|
|
||||||
|
Key assumptions tested:
|
||||||
|
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||||
|
QueryConfig.
|
||||||
|
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||||
|
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||||
|
|
||||||
|
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||||
|
'get_llm_model_answer' methods.
|
||||||
|
"""
|
||||||
|
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||||
|
mock_retrieve.return_value = ["Test context"]
|
||||||
|
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||||
|
mock_answer.return_value = "Test answer"
|
||||||
|
answer = self.app.query("Test query", where={"attribute": "value"})
|
||||||
|
|
||||||
|
self.assertEqual(answer, "Test answer")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||||
|
mock_answer.assert_called_once()
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
|
def test_query_with_where_in_query_config(self):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'query' method in the App class.
|
||||||
|
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||||
|
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||||
|
|
||||||
|
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||||
|
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||||
|
|
||||||
|
Key assumptions tested:
|
||||||
|
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||||
|
QueryConfig.
|
||||||
|
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||||
|
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||||
|
|
||||||
|
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||||
|
'get_llm_model_answer' methods.
|
||||||
|
"""
|
||||||
|
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||||
|
mock_retrieve.return_value = ["Test context"]
|
||||||
|
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||||
|
mock_answer.return_value = "Test answer"
|
||||||
|
queryConfig = QueryConfig(where={"attribute": "value"})
|
||||||
|
answer = self.app.query("Test query", queryConfig)
|
||||||
|
|
||||||
|
self.assertEqual(answer, "Test answer")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||||
|
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||||
|
mock_answer.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user