From 4a5ed1dd8d7ec5f702da6e1d009b4577444afb46 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Wed, 6 Dec 2023 15:23:28 -0800 Subject: [PATCH] Update ec query and chat function (#996) Co-authored-by: Deven Patel --- embedchain/embedchain.py | 10 ++++------ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 47fd2aa4..cc2982c5 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -512,6 +512,7 @@ class EmbedChain(JSONSerializable): config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None, + citations: bool = False, **kwargs: Dict[str, Any], ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: """ @@ -536,10 +537,8 @@ class EmbedChain(JSONSerializable): or the dry run result :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] """ - citations = kwargs.get("citations", False) - db_kwargs = {key: value for key, value in kwargs.items() if key != "citations"} contexts = self._retrieve_from_database( - input_query=input_query, config=config, where=where, citations=citations, **db_kwargs + input_query=input_query, config=config, where=where, citations=citations, **kwargs ) if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) @@ -564,6 +563,7 @@ class EmbedChain(JSONSerializable): config: Optional[BaseLlmConfig] = None, dry_run=False, where: Optional[Dict[str, str]] = None, + citations: bool = False, **kwargs: Dict[str, Any], ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: """ @@ -590,10 +590,8 @@ class EmbedChain(JSONSerializable): or the dry run result :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] """ - citations = kwargs.get("citations", False) - db_kwargs = {key: value for key, value in kwargs.items() if key != "citations"} contexts = self._retrieve_from_database( - input_query=input_query, config=config, where=where, citations=citations, **db_kwargs + input_query=input_query, config=config, where=where, citations=citations, **kwargs ) if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) diff --git a/pyproject.toml b/pyproject.toml index f5d998ff..02f0b73a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.28" +version = "0.1.29" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ",