[Improvements] Package improvements (#993)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
dry_run=False,
|
||||
**kwargs: Dict[str, Any],
|
||||
loader: Optional[BaseLoader] = None,
|
||||
chunker: Optional[BaseChunker] = None,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Adds the data from the given URL to the vector db.
|
||||
@@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
self.user_asks.append([source, data_type.value, metadata])
|
||||
|
||||
data_formatter = DataFormatter(data_type, config, kwargs)
|
||||
data_formatter = DataFormatter(data_type, config, loader, chunker)
|
||||
documents, metadatas, _ids, new_chunks = self._load_and_embed(
|
||||
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
|
||||
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
|
||||
)
|
||||
if data_type in {DataType.DOCS_SITE}:
|
||||
self.is_docs_site_instance = True
|
||||
@@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable):
|
||||
data_type: Optional[DataType] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Adds the data from the given URL to the vector db.
|
||||
@@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable):
|
||||
data_type=data_type,
|
||||
metadata=metadata,
|
||||
config=config,
|
||||
kwargs=kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
|
||||
@@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
source_hash: Optional[str] = None,
|
||||
dry_run=False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Loads the data from the given URL, chunks it, and adds it to database.
|
||||
@@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable):
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
||||
**kwargs,
|
||||
)
|
||||
count_new_chunks = self.db.count() - chunks_before_addition
|
||||
|
||||
@@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable):
|
||||
]
|
||||
|
||||
def _retrieve_from_database(
|
||||
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
|
||||
self,
|
||||
input_query: str,
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
where=None,
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
@@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable):
|
||||
where=where,
|
||||
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
|
||||
citations=citations,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return contexts
|
||||
@@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable):
|
||||
or the dry run result
|
||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||
"""
|
||||
citations = kwargs.get("citations", False)
|
||||
if "citations" in kwargs:
|
||||
citations = kwargs.pop("citations")
|
||||
else:
|
||||
citations = False
|
||||
|
||||
contexts = self._retrieve_from_database(
|
||||
input_query=input_query, config=config, where=where, citations=citations
|
||||
input_query=input_query, config=config, where=where, citations=citations, **kwargs
|
||||
)
|
||||
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||
@@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable):
|
||||
or the dry run result
|
||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||
"""
|
||||
citations = kwargs.get("citations", False)
|
||||
if "citations" in kwargs:
|
||||
citations = kwargs.pop("citations")
|
||||
else:
|
||||
citations = False
|
||||
|
||||
contexts = self._retrieve_from_database(
|
||||
input_query=input_query, config=config, where=where, citations=citations
|
||||
input_query=input_query, config=config, where=where, citations=citations, **kwargs
|
||||
)
|
||||
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||
|
||||
Reference in New Issue
Block a user