From 3d0e4141bf34e97cb0c056bf3b7dd554efb146c1 Mon Sep 17 00:00:00 2001
From: cachho
Date: Sun, 17 Sep 2023 19:52:12 +0200
Subject: [PATCH] refactor: get existing doc id method (#616)

---
Review notes (text in this area is not part of the commit message):

* Purpose: the lookup of an already stored document id is moved out of
  load_and_embed_v2 into the new method _get_existing_doc_id, which
  load_and_embed_v2 now calls with chunker= and src=.
* Behaviour change hidden in the refactor: the removed code wrapped
  existing_embeddings_data.get("metadatas", [])[0]["doc_id"] in
  try/except Exception and fell back to None. The new method only checks
  that "metadatas" is not empty, so a matched record without a "doc_id"
  key now raises KeyError, and a None value stored under "metadatas" now
  raises TypeError in len(), where both used to yield None.
* The "url" and "question" branches of the new method run the same query
  and the same result handling; only the where filter differs.
* Typo carried over from the removed lines: "a indirect source reference"
  should read "an indirect source reference".

 embedchain/embedchain.py | 96 +++++++++++++++++++++-------------------
 1 file changed, 51 insertions(+), 45 deletions(-)

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 03fcdeeb..d60a8bd2 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -322,6 +322,56 @@ class EmbedChain(JSONSerializable):
         count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
+
+    def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
+        """
+        Get id of existing document for a given source, based on the data type
+        """
+        # Find existing embeddings for the source
+        # Depending on the data type, existing embeddings are checked for.
+        if chunker.data_type.value in [item.value for item in DirectDataType]:
+            # DirectDataTypes can't be updated.
+            # Think of a text:
+            # Either it's the same, then it won't change, so it's not an update.
+            # Or it's different, then it will be added as a new text.
+            return None
+        elif chunker.data_type.value in [item.value for item in IndirectDataType]:
+            # These types have a indirect source reference
+            # As long as the reference is the same, they can be updated.
+            existing_embeddings_data = self.db.get(
+                where={
+                    "url": src,
+                },
+                limit=1,
+            )
+            if len(existing_embeddings_data.get("metadatas", [])) > 0:
+                return existing_embeddings_data["metadatas"][0]["doc_id"]
+            else:
+                return None
+        elif chunker.data_type.value in [item.value for item in SpecialDataType]:
+            # These types don't contain indirect references.
+            # Through custom logic, they can be attributed to a source and be updated.
+            if chunker.data_type == DataType.QNA_PAIR:
+                # QNA_PAIRs update the answer if the question already exists.
+                existing_embeddings_data = self.db.get(
+                    where={
+                        "question": src[0],
+                    },
+                    limit=1,
+                )
+                if len(existing_embeddings_data.get("metadatas", [])) > 0:
+                    return existing_embeddings_data["metadatas"][0]["doc_id"]
+                else:
+                    return None
+            else:
+                raise NotImplementedError(
+                    f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
+                )
+        else:
+            raise TypeError(
+                f"{chunker.data_type} is type {type(chunker.data_type)}. "
+                "When it should be DirectDataType, IndirectDataType or SpecialDataType."
+            )
 
     def load_and_embed_v2(
         self,
@@ -343,51 +393,7 @@ class EmbedChain(JSONSerializable):
         :param source_id: Hexadecimal hash of the source.
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
-        # Find existing embeddings for the source
-        # Depending on the data type, existing embeddings are checked for.
-        if chunker.data_type.value in [item.value for item in DirectDataType]:
-            # DirectDataTypes can't be updated.
-            # Think of a text:
-            # Either it's the same, then it won't change, so it's not an update.
-            # Or it's different, then it will be added as a new text.
-            existing_doc_id = None
-        elif chunker.data_type.value in [item.value for item in IndirectDataType]:
-            # These types have a indirect source reference
-            # As long as the reference is the same, they can be updated.
-            existing_embeddings_data = self.db.get(
-                where={
-                    "url": src,
-                },
-                limit=1,
-            )
-            try:
-                existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
-            except Exception:
-                existing_doc_id = None
-        elif chunker.data_type.value in [item.value for item in SpecialDataType]:
-            # These types don't contain indirect references.
-            # Through custom logic, they can be attributed to a source and be updated.
-            if chunker.data_type == DataType.QNA_PAIR:
-                # QNA_PAIRs update the answer if the question already exists.
-                existing_embeddings_data = self.db.get(
-                    where={
-                        "question": src[0],
-                    },
-                    limit=1,
-                )
-                try:
-                    existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
-                except Exception:
-                    existing_doc_id = None
-            else:
-                raise NotImplementedError(
-                    f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
-                )
-        else:
-            raise TypeError(
-                f"{chunker.data_type} is type {type(chunker.data_type)}. "
-                "When it should be DirectDataType, IndirectDataType or SpecialDataType."
-            )
+        existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
 
         # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)