diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index cbb1dfb6..9374743b 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -313,9 +313,6 @@ class EmbedChain(JSONSerializable): ids = list(data_dict.keys()) documents, metadatas = zip(*data_dict.values()) - if dry_run: - return list(documents), metadatas, ids, 0 - # Loop though all metadatas and add extras. new_metadatas = [] for m in metadatas: @@ -334,6 +331,9 @@ class EmbedChain(JSONSerializable): new_metadatas.append(m) metadatas = new_metadatas + if dry_run: + return list(documents), metadatas, ids, 0 + # Count before, to calculate a delta in the end. chunks_before_addition = self.db.count() @@ -410,6 +410,8 @@ class EmbedChain(JSONSerializable): remote sources or local content for local loaders. :param metadata: Optional. Metadata associated with the data source. :param source_id: Hexadecimal hash of the source. + :param dry_run: Optional. A dry run returns chunks and doesn't update DB. + :type dry_run: bool, defaults to False :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks """ existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src) @@ -474,6 +476,9 @@ class EmbedChain(JSONSerializable): new_metadatas.append(m) metadatas = new_metadatas + if dry_run: + return list(documents), metadatas, ids, 0 + # Count before, to calculate a delta in the end. chunks_before_addition = self.count()