fix: url metadata for all datatypes (#613)
This commit is contained in:
@@ -21,7 +21,8 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.utils import detect_datatype
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
@@ -339,16 +340,53 @@ class EmbedChain(JSONSerializable):
|
||||
:param source_id: Hexadecimal hash of the source.
|
||||
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
|
||||
"""
|
||||
existing_embeddings_data = self.db.get(
|
||||
where={
|
||||
"url": src,
|
||||
},
|
||||
limit=1,
|
||||
)
|
||||
try:
|
||||
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
|
||||
except Exception:
|
||||
# Find existing embeddings for the source
|
||||
# Depending on the data type, existing embeddings are checked for.
|
||||
if chunker.data_type.value in [item.value for item in DirectDataType]:
|
||||
# DirectDataTypes can't be updated.
|
||||
# Think of a text:
|
||||
# Either it's the same, then it won't change, so it's not an update.
|
||||
# Or it's different, then it will be added as a new text.
|
||||
existing_doc_id = None
|
||||
elif chunker.data_type.value in [item.value for item in IndirectDataType]:
|
||||
# These types have a indirect source reference
|
||||
# As long as the reference is the same, they can be updated.
|
||||
existing_embeddings_data = self.db.get(
|
||||
where={
|
||||
"url": src,
|
||||
},
|
||||
limit=1,
|
||||
)
|
||||
try:
|
||||
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
|
||||
except Exception:
|
||||
existing_doc_id = None
|
||||
elif chunker.data_type.value in [item.value for item in SpecialDataType]:
|
||||
# These types don't contain indirect references.
|
||||
# Through custom logic, they can be attributed to a source and be updated.
|
||||
if chunker.data_type == DataType.QNA_PAIR:
|
||||
# QNA_PAIRs update the answer if the question already exists.
|
||||
existing_embeddings_data = self.db.get(
|
||||
where={
|
||||
"question": src[0],
|
||||
},
|
||||
limit=1,
|
||||
)
|
||||
try:
|
||||
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
|
||||
except Exception:
|
||||
existing_doc_id = None
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{chunker.data_type} is type {type(chunker.data_type)}. "
|
||||
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
|
||||
)
|
||||
|
||||
# Create chunks
|
||||
embeddings_data = chunker.create_chunks(loader, src)
|
||||
|
||||
# spread chunking results
|
||||
|
||||
@@ -11,9 +11,7 @@ class LocalQnaPairLoader(BaseLoader):
|
||||
question, answer = content
|
||||
content = f"Q: {question}\nA: {answer}"
|
||||
url = "local"
|
||||
meta_data = {
|
||||
"url": url,
|
||||
}
|
||||
meta_data = {"url": url, "question": question}
|
||||
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
|
||||
return {
|
||||
"doc_id": doc_id,
|
||||
|
||||
@@ -1,15 +1,47 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
class DirectDataType(Enum):
|
||||
"""
|
||||
DirectDataType enum contains data types that contain raw data directly.
|
||||
"""
|
||||
|
||||
TEXT = "text"
|
||||
|
||||
|
||||
class IndirectDataType(Enum):
|
||||
"""
|
||||
IndirectDataType enum contains data types that contain references to data stored elsewhere.
|
||||
"""
|
||||
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
PDF_FILE = "pdf_file"
|
||||
WEB_PAGE = "web_page"
|
||||
SITEMAP = "sitemap"
|
||||
DOCX = "docx"
|
||||
DOCS_SITE = "docs_site"
|
||||
TEXT = "text"
|
||||
QNA_PAIR = "qna_pair"
|
||||
NOTION = "notion"
|
||||
CSV = "csv"
|
||||
MDX = "mdx"
|
||||
|
||||
|
||||
class SpecialDataType(Enum):
|
||||
"""
|
||||
SpecialDataType enum contains data types that are neither direct nor indirect, or simply require special attention.
|
||||
"""
|
||||
|
||||
QNA_PAIR = "qna_pair"
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
TEXT = DirectDataType.TEXT.value
|
||||
YOUTUBE_VIDEO = IndirectDataType.YOUTUBE_VIDEO.value
|
||||
PDF_FILE = IndirectDataType.PDF_FILE.value
|
||||
WEB_PAGE = IndirectDataType.WEB_PAGE.value
|
||||
SITEMAP = IndirectDataType.SITEMAP.value
|
||||
DOCX = IndirectDataType.DOCX.value
|
||||
DOCS_SITE = IndirectDataType.DOCS_SITE.value
|
||||
NOTION = IndirectDataType.NOTION.value
|
||||
CSV = IndirectDataType.CSV.value
|
||||
MDX = IndirectDataType.MDX.value
|
||||
QNA_PAIR = SpecialDataType.QNA_PAIR.value
|
||||
|
||||
32
tests/models/test_data_type.py
Normal file
32
tests/models/test_data_type.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import unittest
|
||||
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
|
||||
|
||||
class TestDataTypeEnums(unittest.TestCase):
|
||||
def test_subclass_types_in_data_type(self):
|
||||
"""Test that all data type category subclasses are contained in the composite data type"""
|
||||
# Check if DirectDataType values are in DataType
|
||||
for data_type in DirectDataType:
|
||||
self.assertIn(data_type.value, DataType._value2member_map_)
|
||||
|
||||
# Check if IndirectDataType values are in DataType
|
||||
for data_type in IndirectDataType:
|
||||
self.assertIn(data_type.value, DataType._value2member_map_)
|
||||
|
||||
# Check if SpecialDataType values are in DataType
|
||||
for data_type in SpecialDataType:
|
||||
self.assertIn(data_type.value, DataType._value2member_map_)
|
||||
|
||||
def test_data_type_in_subclasses(self):
|
||||
"""Test that all data types in the composite data type are categorized in a subclass"""
|
||||
for data_type in DataType:
|
||||
if data_type.value in DirectDataType._value2member_map_:
|
||||
self.assertIn(data_type.value, DirectDataType._value2member_map_)
|
||||
elif data_type.value in IndirectDataType._value2member_map_:
|
||||
self.assertIn(data_type.value, IndirectDataType._value2member_map_)
|
||||
elif data_type.value in SpecialDataType._value2member_map_:
|
||||
self.assertIn(data_type.value, SpecialDataType._value2member_map_)
|
||||
else:
|
||||
self.fail(f"{data_type.value} not found in any subclass enums")
|
||||
Reference in New Issue
Block a user