From 79efa51941e5300f35febe594e33f69005519db1 Mon Sep 17 00:00:00 2001 From: cachho Date: Wed, 13 Sep 2023 19:19:48 +0200 Subject: [PATCH] fix: url metadata for all datatypes (#613) --- embedchain/embedchain.py | 58 +++++++++++++++++++++++----- embedchain/loaders/local_qna_pair.py | 4 +- embedchain/models/data_type.py | 38 ++++++++++++++++-- tests/models/test_data_type.py | 32 +++++++++++++++ 4 files changed, 116 insertions(+), 16 deletions(-) create mode 100644 tests/models/test_data_type.py diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index d6364b5b..dbc5d7c9 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -21,7 +21,8 @@ from embedchain.embedder.base import BaseEmbedder from embedchain.helper.json_serializable import JSONSerializable from embedchain.llm.base import BaseLlm from embedchain.loaders.base_loader import BaseLoader -from embedchain.models.data_type import DataType +from embedchain.models.data_type import (DataType, DirectDataType, + IndirectDataType, SpecialDataType) from embedchain.utils import detect_datatype from embedchain.vectordb.base import BaseVectorDB @@ -339,16 +340,53 @@ class EmbedChain(JSONSerializable): :param source_id: Hexadecimal hash of the source. :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks """ - existing_embeddings_data = self.db.get( - where={ - "url": src, - }, - limit=1, - ) - try: - existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"] - except Exception: + # Find existing embeddings for the source + # Depending on the data type, existing embeddings are checked for. + if chunker.data_type.value in [item.value for item in DirectDataType]: + # DirectDataTypes can't be updated. + # Think of a text: + # Either it's the same, then it won't change, so it's not an update. + # Or it's different, then it will be added as a new text. 
existing_doc_id = None + elif chunker.data_type.value in [item.value for item in IndirectDataType]: + # These types have an indirect source reference + # As long as the reference is the same, they can be updated. + existing_embeddings_data = self.db.get( + where={ + "url": src, + }, + limit=1, + ) + try: + existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"] + except Exception: + existing_doc_id = None + elif chunker.data_type.value in [item.value for item in SpecialDataType]: + # These types don't contain indirect references. + # Through custom logic, they can be attributed to a source and be updated. + if chunker.data_type == DataType.QNA_PAIR: + # QNA_PAIRs update the answer if the question already exists. + existing_embeddings_data = self.db.get( + where={ + "question": src[0], + }, + limit=1, + ) + try: + existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"] + except Exception: + existing_doc_id = None + else: + raise NotImplementedError( + f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data" + ) + else: + raise TypeError( + f"{chunker.data_type} is type {type(chunker.data_type)}. " + "When it should be DirectDataType, IndirectDataType or SpecialDataType." 
+ ) + + # Create chunks embeddings_data = chunker.create_chunks(loader, src) # spread chunking results diff --git a/embedchain/loaders/local_qna_pair.py b/embedchain/loaders/local_qna_pair.py index 61d9a576..ffaa6fea 100644 --- a/embedchain/loaders/local_qna_pair.py +++ b/embedchain/loaders/local_qna_pair.py @@ -11,9 +11,7 @@ class LocalQnaPairLoader(BaseLoader): question, answer = content content = f"Q: {question}\nA: {answer}" url = "local" - meta_data = { - "url": url, - } + meta_data = {"url": url, "question": question} doc_id = hashlib.sha256((content + url).encode()).hexdigest() return { "doc_id": doc_id, diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 8c13eb06..d41bf0bd 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -1,15 +1,47 @@ from enum import Enum -class DataType(Enum): +class DirectDataType(Enum): + """ + DirectDataType enum contains data types that contain raw data directly. + """ + + TEXT = "text" + + +class IndirectDataType(Enum): + """ + IndirectDataType enum contains data types that contain references to data stored elsewhere. + """ + YOUTUBE_VIDEO = "youtube_video" PDF_FILE = "pdf_file" WEB_PAGE = "web_page" SITEMAP = "sitemap" DOCX = "docx" DOCS_SITE = "docs_site" - TEXT = "text" - QNA_PAIR = "qna_pair" NOTION = "notion" CSV = "csv" MDX = "mdx" + + +class SpecialDataType(Enum): + """ + SpecialDataType enum contains data types that are neither direct nor indirect, or simply require special attention. 
+ """ + + QNA_PAIR = "qna_pair" + + +class DataType(Enum): + TEXT = DirectDataType.TEXT.value + YOUTUBE_VIDEO = IndirectDataType.YOUTUBE_VIDEO.value + PDF_FILE = IndirectDataType.PDF_FILE.value + WEB_PAGE = IndirectDataType.WEB_PAGE.value + SITEMAP = IndirectDataType.SITEMAP.value + DOCX = IndirectDataType.DOCX.value + DOCS_SITE = IndirectDataType.DOCS_SITE.value + NOTION = IndirectDataType.NOTION.value + CSV = IndirectDataType.CSV.value + MDX = IndirectDataType.MDX.value + QNA_PAIR = SpecialDataType.QNA_PAIR.value diff --git a/tests/models/test_data_type.py b/tests/models/test_data_type.py new file mode 100644 index 00000000..7b2f173e --- /dev/null +++ b/tests/models/test_data_type.py @@ -0,0 +1,32 @@ +import unittest + +from embedchain.models.data_type import (DataType, DirectDataType, + IndirectDataType, SpecialDataType) + + +class TestDataTypeEnums(unittest.TestCase): + def test_subclass_types_in_data_type(self): + """Test that all data type category subclasses are contained in the composite data type""" + # Check if DirectDataType values are in DataType + for data_type in DirectDataType: + self.assertIn(data_type.value, DataType._value2member_map_) + + # Check if IndirectDataType values are in DataType + for data_type in IndirectDataType: + self.assertIn(data_type.value, DataType._value2member_map_) + + # Check if SpecialDataType values are in DataType + for data_type in SpecialDataType: + self.assertIn(data_type.value, DataType._value2member_map_) + + def test_data_type_in_subclasses(self): + """Test that all data types in the composite data type are categorized in a subclass""" + for data_type in DataType: + if data_type.value in DirectDataType._value2member_map_: + self.assertIn(data_type.value, DirectDataType._value2member_map_) + elif data_type.value in IndirectDataType._value2member_map_: + self.assertIn(data_type.value, IndirectDataType._value2member_map_) + elif data_type.value in SpecialDataType._value2member_map_: + self.assertIn(data_type.value, 
SpecialDataType._value2member_map_) + else: + self.fail(f"{data_type.value} not found in any subclass enums")