fix: url metadata for all datatypes (#613)

This commit is contained in:
cachho
2023-09-13 19:19:48 +02:00
committed by GitHub
parent 701d0b21ef
commit 79efa51941
4 changed files with 116 additions and 16 deletions

View File

@@ -21,7 +21,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.utils import detect_datatype
from embedchain.vectordb.base import BaseVectorDB
@@ -339,16 +340,53 @@ class EmbedChain(JSONSerializable):
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
"""
existing_embeddings_data = self.db.get(
where={
"url": src,
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
# Find existing embeddings for the source
# Depending on the data type, existing embeddings are checked for.
if chunker.data_type.value in [item.value for item in DirectDataType]:
# DirectDataTypes can't be updated.
# Think of a text:
# Either it's the same, then it won't change, so it's not an update.
# Or it's different, then it will be added as a new text.
existing_doc_id = None
elif chunker.data_type.value in [item.value for item in IndirectDataType]:
# These types have a indirect source reference
# As long as the reference is the same, they can be updated.
existing_embeddings_data = self.db.get(
where={
"url": src,
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
elif chunker.data_type.value in [item.value for item in SpecialDataType]:
# These types don't contain indirect references.
# Through custom logic, they can be attributed to a source and be updated.
if chunker.data_type == DataType.QNA_PAIR:
# QNA_PAIRs update the answer if the question already exists.
existing_embeddings_data = self.db.get(
where={
"question": src[0],
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
else:
raise NotImplementedError(
f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
)
else:
raise TypeError(
f"{chunker.data_type} is type {type(chunker.data_type)}. "
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
)
# Create chunks
embeddings_data = chunker.create_chunks(loader, src)
# spread chunking results

View File

@@ -11,9 +11,7 @@ class LocalQnaPairLoader(BaseLoader):
question, answer = content
content = f"Q: {question}\nA: {answer}"
url = "local"
meta_data = {
"url": url,
}
meta_data = {"url": url, "question": question}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,

View File

@@ -1,15 +1,47 @@
from enum import Enum
class DataType(Enum):
class DirectDataType(Enum):
"""
DirectDataType enum contains data types that contain raw data directly.
"""
TEXT = "text"
class IndirectDataType(Enum):
"""
IndirectDataType enum contains data types that contain references to data stored elsewhere.
"""
YOUTUBE_VIDEO = "youtube_video"
PDF_FILE = "pdf_file"
WEB_PAGE = "web_page"
SITEMAP = "sitemap"
DOCX = "docx"
DOCS_SITE = "docs_site"
TEXT = "text"
QNA_PAIR = "qna_pair"
NOTION = "notion"
CSV = "csv"
MDX = "mdx"
class SpecialDataType(Enum):
"""
SpecialDataType enum contains data types that are neither direct nor indirect, or simply require special attention.
"""
QNA_PAIR = "qna_pair"
class DataType(Enum):
TEXT = DirectDataType.TEXT.value
YOUTUBE_VIDEO = IndirectDataType.YOUTUBE_VIDEO.value
PDF_FILE = IndirectDataType.PDF_FILE.value
WEB_PAGE = IndirectDataType.WEB_PAGE.value
SITEMAP = IndirectDataType.SITEMAP.value
DOCX = IndirectDataType.DOCX.value
DOCS_SITE = IndirectDataType.DOCS_SITE.value
NOTION = IndirectDataType.NOTION.value
CSV = IndirectDataType.CSV.value
MDX = IndirectDataType.MDX.value
QNA_PAIR = SpecialDataType.QNA_PAIR.value

View File

@@ -0,0 +1,32 @@
import unittest
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
class TestDataTypeEnums(unittest.TestCase):
def test_subclass_types_in_data_type(self):
"""Test that all data type category subclasses are contained in the composite data type"""
# Check if DirectDataType values are in DataType
for data_type in DirectDataType:
self.assertIn(data_type.value, DataType._value2member_map_)
# Check if IndirectDataType values are in DataType
for data_type in IndirectDataType:
self.assertIn(data_type.value, DataType._value2member_map_)
# Check if SpecialDataType values are in DataType
for data_type in SpecialDataType:
self.assertIn(data_type.value, DataType._value2member_map_)
def test_data_type_in_subclasses(self):
"""Test that all data types in the composite data type are categorized in a subclass"""
for data_type in DataType:
if data_type.value in DirectDataType._value2member_map_:
self.assertIn(data_type.value, DirectDataType._value2member_map_)
elif data_type.value in IndirectDataType._value2member_map_:
self.assertIn(data_type.value, IndirectDataType._value2member_map_)
elif data_type.value in SpecialDataType._value2member_map_:
self.assertIn(data_type.value, SpecialDataType._value2member_map_)
else:
self.fail(f"{data_type.value} not found in any subclass enums")