[Improvements] Package improvements (#993)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Author: Deven Patel
Date: 2023-12-05 23:42:45 -08:00 (committed by GitHub)
Commit: 51b4966801 (parent 1d4e00ccef)
13 changed files with 96 additions and 40 deletions


@@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
- **kwargs: Dict[str, Any],
+ loader: Optional[BaseLoader] = None,
+ chunker: Optional[BaseChunker] = None,
+ **kwargs: Optional[Dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
@@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable):
self.user_asks.append([source, data_type.value, metadata])
- data_formatter = DataFormatter(data_type, config, kwargs)
+ data_formatter = DataFormatter(data_type, config, loader, chunker)
documents, metadatas, _ids, new_chunks = self._load_and_embed(
- data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
+ data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
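With this change, add() takes an explicit loader and chunker, hands them to DataFormatter instead of the raw kwargs dict, and forwards the remaining keyword arguments to _load_and_embed(). A minimal sketch of how a caller might use the new parameters; the App entry point, the import path, and the MyLoader subclass with its load_data() return shape are illustrative assumptions, not part of this diff:

from embedchain import App  # assumed entry point built on EmbedChain
from embedchain.loaders.base_loader import BaseLoader  # assumed import path


class MyLoader(BaseLoader):
    # Hypothetical loader; the doc_id/data return shape below is an assumption.
    def load_data(self, url):
        return {
            "doc_id": "example-doc",
            "data": [{"content": f"text fetched from {url}", "meta_data": {"url": url}}],
        }


app = App()
# loader/chunker now go straight to DataFormatter; chunker=None falls back to the default.
# Any extra keyword arguments are forwarded to _load_and_embed().
app.add("https://example.com", loader=MyLoader(), chunker=None)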
@@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable):
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
- **kwargs: Dict[str, Any],
+ **kwargs: Optional[Dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
@@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable):
data_type=data_type,
metadata=metadata,
config=config,
- kwargs=kwargs,
+ **kwargs,
)
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
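The wrapper above previously passed kwargs=kwargs, i.e. the whole dict as a single keyword argument literally named kwargs, so the inner add() received {"kwargs": {...}} instead of the individual options; unpacking with ** forwards each key/value as its own keyword argument. A standalone illustration of the difference (plain Python, not embedchain code):

def callee(**kwargs):
    return kwargs

opts = {"citations": True}

print(callee(kwargs=opts))  # {'kwargs': {'citations': True}} -- options nested one level too deep
print(callee(**opts))       # {'citations': True} -- each option arrives as its own keyword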
@@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None,
source_hash: Optional[str] = None,
dry_run=False,
+ **kwargs: Optional[Dict[str, Any]],
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
@@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable):
metadatas=metadatas,
ids=ids,
skip_embedding=(chunker.data_type == DataType.IMAGES),
+ **kwargs,
)
count_new_chunks = self.db.count() - chunks_before_addition
@@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable):
]
def _retrieve_from_database(
- self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
+ self,
+ input_query: str,
+ config: Optional[BaseLlmConfig] = None,
+ where=None,
+ citations: bool = False,
+ **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
Queries the vector database based on the given input query.
@@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable):
where=where,
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
citations=citations,
+ **kwargs,
)
return contexts
@@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable):
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
"""
- citations = kwargs.get("citations", False)
+ if "citations" in kwargs:
+     citations = kwargs.pop("citations")
+ else:
+     citations = False
contexts = self._retrieve_from_database(
- input_query=input_query, config=config, where=where, citations=citations
+ input_query=input_query, config=config, where=where, citations=citations, **kwargs
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
@@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable):
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
"""
- citations = kwargs.get("citations", False)
+ if "citations" in kwargs:
+     citations = kwargs.pop("citations")
+ else:
+     citations = False
contexts = self._retrieve_from_database(
- input_query=input_query, config=config, where=where, citations=citations
+ input_query=input_query, config=config, where=where, citations=citations, **kwargs
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
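Both query() and chat() now pop citations out of kwargs before forwarding the remaining keyword arguments to _retrieve_from_database(), so retrieval no longer receives an unexpected citations keyword and extra options can flow through to the vector database query. A minimal usage sketch, assuming the usual App entry point (not shown in this diff):

from embedchain import App  # assumed entry point built on EmbedChain

app = App()
app.add("https://example.com")

# With citations=True the result is a tuple: the answer plus the retrieved contexts.
answer, sources = app.query("What does the page say?", citations=True)
for citation in sources:
    # Per the type hints above, each citation is a 3-tuple whose first element is the context text.
    print(citation[0])

# Without citations (the default), only the answer string is returned.
answer = app.chat("Summarize the page.")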