Add dry_run to add() (#545)

This commit is contained in:
Dev Khant
2023-09-12 09:20:31 +05:30
committed by GitHub
parent 79f5a1d052
commit 7c39d9f0c1
4 changed files with 60 additions and 9 deletions

View File

@@ -125,6 +125,7 @@ class EmbedChain(JSONSerializable):
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
):
"""
Adds the data from the given URL to the vector db.
@@ -141,6 +142,8 @@ class EmbedChain(JSONSerializable):
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
:param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
defaults to False
:return: source_id, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
@@ -176,12 +179,17 @@ class EmbedChain(JSONSerializable):
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
documents, _metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_id
documents, metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
if dry_run:
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
logging.debug(f"Dry run info : {data_chunks_info}")
return data_chunks_info
# Send anonymous telemetry
if self.config.collect_metrics:
# it's quicker to check the variable twice than to count words when they won't be submitted.
@@ -233,6 +241,7 @@ class EmbedChain(JSONSerializable):
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
dry_run = False
) -> Tuple[List[str], Dict[str, Any], List[str], int]:
"""The loader to use to load the data.
@@ -247,6 +256,8 @@ class EmbedChain(JSONSerializable):
:type metadata: Dict[str, Any], optional
:param source_id: Hexadecimal hash of the source., defaults to None
:type source_id: str, optional
:param dry_run: Optional. A dry run returns chunks and doesn't update DB.
:type dry_run: bool, defaults to False
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
:rtype: Tuple[List[str], Dict[str, Any], List[str], int]
"""
@@ -277,6 +288,9 @@ class EmbedChain(JSONSerializable):
ids = list(data_dict.keys())
documents, metadatas = zip(*data_dict.values())
if dry_run:
return list(documents), metadatas, ids, 0
# Loop through all metadatas and add extras.
new_metadatas = []
for m in metadatas: