feat: where filter in vector database (#518)

This commit is contained in:
sw8fbar
2023-09-04 15:49:59 -05:00
committed by GitHub
parent 202fd2d5b6
commit 3e66ddf69a
6 changed files with 156 additions and 14 deletions

View File

@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
"""
raise NotImplementedError
def retrieve_from_database(self, input_query, config: QueryConfig):
def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query
:param input_query: The query to use.
:param config: The query configuration.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The content of the document that matched your query.
"""
where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
if where is not None:
where = where
elif config is not None and config.where is not None:
where = config.where
else:
where = {}
if self.config.id is not None:
where.update({"app_id": self.config.id})
contents = self.db.query(
input_query=input_query,
n_results=config.number_documents,
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def query(self, input_query, config: QueryConfig = None, dry_run=False):
def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config)
contexts = self.retrieve_from_database(input_query, config, where)
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
yield chunk
logging.info(f"Answer: {streamed_answer}")
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config)
contexts = self.retrieve_from_database(input_query, config, where)
chat_history = self.memory.load_memory_variables({})["history"]