fix: url metadata for all datatypes (#613)
@@ -21,7 +21,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.utils import detect_datatype
 from embedchain.vectordb.base import BaseVectorDB
 
@@ -339,16 +340,53 @@ class EmbedChain(JSONSerializable):
         :param source_id: Hexadecimal hash of the source.
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
-        existing_embeddings_data = self.db.get(
-            where={
-                "url": src,
-            },
-            limit=1,
-        )
-        try:
-            existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
-        except Exception:
-            existing_doc_id = None
+        # Find existing embeddings for the source
+        # Depending on the data type, existing embeddings are checked for.
+        if chunker.data_type.value in [item.value for item in DirectDataType]:
+            # DirectDataTypes can't be updated.
+            # Think of a text:
+            # Either it's the same, then it won't change, so it's not an update.
+            # Or it's different, then it will be added as a new text.
+            existing_doc_id = None
+        elif chunker.data_type.value in [item.value for item in IndirectDataType]:
+            # These types have a indirect source reference
+            # As long as the reference is the same, they can be updated.
+            existing_embeddings_data = self.db.get(
+                where={
+                    "url": src,
+                },
+                limit=1,
+            )
+            try:
+                existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+            except Exception:
+                existing_doc_id = None
+        elif chunker.data_type.value in [item.value for item in SpecialDataType]:
+            # These types don't contain indirect references.
+            # Through custom logic, they can be attributed to a source and be updated.
+            if chunker.data_type == DataType.QNA_PAIR:
+                # QNA_PAIRs update the answer if the question already exists.
+                existing_embeddings_data = self.db.get(
+                    where={
+                        "question": src[0],
+                    },
+                    limit=1,
+                )
+                try:
+                    existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+                except Exception:
+                    existing_doc_id = None
+            else:
+                raise NotImplementedError(
+                    f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
+                )
+        else:
+            raise TypeError(
+                f"{chunker.data_type} is type {type(chunker.data_type)}. "
+                "When it should be DirectDataType, IndirectDataType or SpecialDataType."
+            )
 
+        # Create chunks
         embeddings_data = chunker.create_chunks(loader, src)
 
         # spread chunking results
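Note (illustrative, not part of the commit): the sketch below mirrors the branching added above as a standalone function. The helper name existing_source_filter is hypothetical; it only shows which metadata filter each data-type category resolves to, and assumes embedchain with this change is importable.

from embedchain.models.data_type import (DataType, DirectDataType,
                                         IndirectDataType, SpecialDataType)


def existing_source_filter(data_type, src):
    """Hypothetical helper: return the metadata filter used to look up
    existing embeddings for src, or None when nothing can be updated."""
    if data_type.value in [item.value for item in DirectDataType]:
        return None  # raw text is never updated in place
    if data_type.value in [item.value for item in IndirectDataType]:
        return {"url": src}  # indirect sources are keyed by their url
    if data_type.value in [item.value for item in SpecialDataType]:
        if data_type == DataType.QNA_PAIR:
            return {"question": src[0]}  # src is a (question, answer) tuple
        raise NotImplementedError(f"No existing-data check defined for {data_type}")
    raise TypeError(f"{data_type} does not belong to any data type category")


print(existing_source_filter(DataType.WEB_PAGE, "https://example.com"))  # {'url': 'https://example.com'}
print(existing_source_filter(DataType.QNA_PAIR, ("Q?", "A.")))           # {'question': 'Q?'}
print(existing_source_filter(DataType.TEXT, "hello world"))              # None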
@@ -11,9 +11,7 @@ class LocalQnaPairLoader(BaseLoader):
         question, answer = content
         content = f"Q: {question}\nA: {answer}"
         url = "local"
-        meta_data = {
-            "url": url,
-        }
+        meta_data = {"url": url, "question": question}
         doc_id = hashlib.sha256((content + url).encode()).hexdigest()
         return {
             "doc_id": doc_id,
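Note (illustrative, not part of the commit): with the question mirrored into the chunk metadata, a stored QnA pair can later be found again by its question. A rough sketch of the values the lines above produce, using made-up content:

import hashlib

question, answer = ("What is Embedchain?", "A framework for building RAG apps.")
content = f"Q: {question}\nA: {answer}"
url = "local"
meta_data = {"url": url, "question": question}  # question now travels with the metadata
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
print(doc_id)
print(meta_data)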
@@ -1,15 +1,47 @@
 from enum import Enum
 
 
-class DataType(Enum):
+class DirectDataType(Enum):
+    """
+    DirectDataType enum contains data types that contain raw data directly.
+    """
+
+    TEXT = "text"
+
+
+class IndirectDataType(Enum):
+    """
+    IndirectDataType enum contains data types that contain references to data stored elsewhere.
+    """
+
     YOUTUBE_VIDEO = "youtube_video"
     PDF_FILE = "pdf_file"
     WEB_PAGE = "web_page"
     SITEMAP = "sitemap"
     DOCX = "docx"
     DOCS_SITE = "docs_site"
-    TEXT = "text"
-    QNA_PAIR = "qna_pair"
     NOTION = "notion"
     CSV = "csv"
     MDX = "mdx"
+
+
+class SpecialDataType(Enum):
+    """
+    SpecialDataType enum contains data types that are neither direct nor indirect, or simply require special attention.
+    """
+
+    QNA_PAIR = "qna_pair"
+
+
+class DataType(Enum):
+    TEXT = DirectDataType.TEXT.value
+    YOUTUBE_VIDEO = IndirectDataType.YOUTUBE_VIDEO.value
+    PDF_FILE = IndirectDataType.PDF_FILE.value
+    WEB_PAGE = IndirectDataType.WEB_PAGE.value
+    SITEMAP = IndirectDataType.SITEMAP.value
+    DOCX = IndirectDataType.DOCX.value
+    DOCS_SITE = IndirectDataType.DOCS_SITE.value
+    NOTION = IndirectDataType.NOTION.value
+    CSV = IndirectDataType.CSV.value
+    MDX = IndirectDataType.MDX.value
+    QNA_PAIR = SpecialDataType.QNA_PAIR.value
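Note (illustrative, not part of the commit): the composite DataType reuses the string values of the three category enums, so the same value constructs matching members in both. A minimal check, assuming embedchain with this change is importable:

from embedchain.models.data_type import (DataType, DirectDataType,
                                         IndirectDataType, SpecialDataType)

assert DataType("web_page") is DataType.WEB_PAGE
assert DataType.WEB_PAGE.value == IndirectDataType.WEB_PAGE.value
assert DataType.TEXT.value == DirectDataType.TEXT.value
assert DataType.QNA_PAIR.value == SpecialDataType.QNA_PAIR.value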
tests/models/test_data_type.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import unittest
+
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
+
+
+class TestDataTypeEnums(unittest.TestCase):
+    def test_subclass_types_in_data_type(self):
+        """Test that all data type category subclasses are contained in the composite data type"""
+        # Check if DirectDataType values are in DataType
+        for data_type in DirectDataType:
+            self.assertIn(data_type.value, DataType._value2member_map_)
+
+        # Check if IndirectDataType values are in DataType
+        for data_type in IndirectDataType:
+            self.assertIn(data_type.value, DataType._value2member_map_)
+
+        # Check if SpecialDataType values are in DataType
+        for data_type in SpecialDataType:
+            self.assertIn(data_type.value, DataType._value2member_map_)
+
+    def test_data_type_in_subclasses(self):
+        """Test that all data types in the composite data type are categorized in a subclass"""
+        for data_type in DataType:
+            if data_type.value in DirectDataType._value2member_map_:
+                self.assertIn(data_type.value, DirectDataType._value2member_map_)
+            elif data_type.value in IndirectDataType._value2member_map_:
+                self.assertIn(data_type.value, IndirectDataType._value2member_map_)
+            elif data_type.value in SpecialDataType._value2member_map_:
+                self.assertIn(data_type.value, SpecialDataType._value2member_map_)
+            else:
+                self.fail(f"{data_type.value} not found in any subclass enums")
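Note (not part of the commit): a minimal way to run the new test module with the standard library runner, assuming it is executed from the repository root with embedchain importable:

import unittest

# Discover and run only the new data type tests.
suite = unittest.defaultTestLoader.discover("tests/models", pattern="test_data_type.py")
unittest.TextTestRunner(verbosity=2).run(suite)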