feat: add method - detect format / data_type (#380)
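
With this change, `add` infers the data type from the source itself, so the type argument becomes optional and moves to second position. A minimal usage sketch of the new call shape (illustrative, not part of the diff; assumes the usual `App` entry point and a placeholder video id):

from embedchain import App

app = App()

# data_type is detected automatically from the source:
source_id = app.add("https://www.youtube.com/watch?v=VIDEO_ID")

# or forced explicitly, now as the second argument:
app.add("Lorem ipsum raw text to embed.", "text")

`add` now also returns `source_id`, the md5 hash of the source.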
embedchain/chunkers/base_chunker.py

@@ -1,5 +1,7 @@
 import hashlib
 
+from embedchain.models.data_type import DataType
+
 
 class BaseChunker:
     def __init__(self, text_splitter):
@@ -26,7 +28,7 @@ class BaseChunker:
            meta_data = data["meta_data"]
            # add data type to meta data to allow query using data type
-           meta_data["data_type"] = self.data_type
+           meta_data["data_type"] = self.data_type.value
            url = meta_data["url"]
 
            chunks = self.get_chunks(content)
@@ -52,8 +54,10 @@ class BaseChunker:
        """
        return self.text_splitter.split_text(content)
 
-    def set_data_type(self, data_type):
+    def set_data_type(self, data_type: DataType):
        """
        set the data type of chunker
        """
        self.data_type = data_type
+
+        # TODO: This should be done during initialization. This means it has to be done in the child classes.
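
Storing `self.data_type.value` (a plain string such as "web_page") rather than the enum member keeps chunk metadata made of primitive types, which is what allows filtering on it later. A hedged sketch of the kind of query this enables, assuming a Chroma-style collection and its `where` filter (not shown in this commit):

# chunks carry {"data_type": "web_page", ...} in their metadata,
# so they can be filtered by type at query time:
results = collection.query(
    query_texts=["what is embedchain?"],
    n_results=5,
    where={"data_type": DataType.WEB_PAGE.value},  # == {"data_type": "web_page"}
)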
embedchain/data_formatter/data_formatter.py

@@ -15,6 +15,7 @@ from embedchain.loaders.pdf_file import PdfFileLoader
 from embedchain.loaders.sitemap import SitemapLoader
 from embedchain.loaders.web_page import WebPageLoader
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
+from embedchain.models.data_type import DataType
 
 
 class DataFormatter:
@@ -24,11 +25,11 @@ class DataFormatter:
     .add or .add_local method call
     """
 
-    def __init__(self, data_type: str, config: AddConfig):
+    def __init__(self, data_type: DataType, config: AddConfig):
         self.loader = self._get_loader(data_type, config.loader)
         self.chunker = self._get_chunker(data_type, config.chunker)
 
-    def _get_loader(self, data_type, config):
+    def _get_loader(self, data_type: DataType, config):
         """
         Returns the appropriate data loader for the given data type.
 
@@ -37,22 +38,22 @@ class DataFormatter:
         :raises ValueError: If an unsupported data type is provided.
         """
         loaders = {
-            "youtube_video": YoutubeVideoLoader,
-            "pdf_file": PdfFileLoader,
-            "web_page": WebPageLoader,
-            "qna_pair": LocalQnaPairLoader,
-            "text": LocalTextLoader,
-            "docx": DocxFileLoader,
-            "sitemap": SitemapLoader,
-            "docs_site": DocsSiteLoader,
+            DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
+            DataType.PDF_FILE: PdfFileLoader,
+            DataType.WEB_PAGE: WebPageLoader,
+            DataType.QNA_PAIR: LocalQnaPairLoader,
+            DataType.TEXT: LocalTextLoader,
+            DataType.DOCX: DocxFileLoader,
+            DataType.SITEMAP: SitemapLoader,
+            DataType.DOCS_SITE: DocsSiteLoader,
         }
-        lazy_loaders = ("notion",)
+        lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
             loader_class = loaders[data_type]
             loader = loader_class()
             return loader
         elif data_type in lazy_loaders:
-            if data_type == "notion":
+            if data_type == DataType.NOTION:
                 from embedchain.loaders.notion import NotionLoader
 
                 return NotionLoader()
@@ -61,7 +62,7 @@ class DataFormatter:
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
 
-    def _get_chunker(self, data_type, config):
+    def _get_chunker(self, data_type: DataType, config):
         """
         Returns the appropriate chunker for the given data type.
 
@@ -70,15 +71,15 @@ class DataFormatter:
         :raises ValueError: If an unsupported data type is provided.
         """
         chunker_classes = {
-            "youtube_video": YoutubeVideoChunker,
-            "pdf_file": PdfFileChunker,
-            "web_page": WebPageChunker,
-            "qna_pair": QnaPairChunker,
-            "text": TextChunker,
-            "docx": DocxFileChunker,
-            "sitemap": WebPageChunker,
-            "docs_site": DocsSiteChunker,
-            "notion": NotionChunker,
+            DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
+            DataType.PDF_FILE: PdfFileChunker,
+            DataType.WEB_PAGE: WebPageChunker,
+            DataType.QNA_PAIR: QnaPairChunker,
+            DataType.TEXT: TextChunker,
+            DataType.DOCX: DocxFileChunker,
+            DataType.SITEMAP: WebPageChunker,
+            DataType.DOCS_SITE: DocsSiteChunker,
+            DataType.NOTION: NotionChunker,
        }
        if data_type in chunker_classes:
            chunker_class = chunker_classes[data_type]
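
With both maps keyed by `DataType` members, `DataFormatter` now only accepts the enum; `EmbedChain.add` guarantees that by coercing any string through `DataType(...)` first. A small dispatch sketch (illustrative; the `AddConfig` import path is assumed):

from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter
from embedchain.models.data_type import DataType

formatter = DataFormatter(DataType.PDF_FILE, AddConfig())
# formatter.loader  -> PdfFileLoader instance
# formatter.chunker -> PdfFileChunker instance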
embedchain/embedchain.py

@@ -1,9 +1,10 @@
+import hashlib
 import importlib.metadata
 import logging
 import os
 import threading
 import uuid
-from typing import Optional
+from typing import Dict, Optional
 
 import requests
 from dotenv import load_dotenv
@@ -17,6 +18,8 @@ from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
 from embedchain.loaders.base_loader import BaseLoader
+from embedchain.models.data_type import DataType
+from embedchain.utils import detect_datatype
 
 load_dotenv()
@@ -47,27 +50,62 @@ class EmbedChain:
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()
 
-    def add(self, data_type, url, metadata=None, config: AddConfig = None):
+    def add(
+        self,
+        source,
+        data_type: Optional[DataType] = None,
+        metadata: Optional[Dict] = None,
+        config: Optional[AddConfig] = None,
+    ):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         and then stores the embedding to vector database.
 
-        :param data_type: The type of the data to add.
-        :param url: The URL where the data is located.
+        :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
+        :param data_type: Optional. Automatically detected, but can be forced with this argument.
+        The type of the data to add.
         :param metadata: Optional. Metadata associated with the data source.
         :param config: Optional. The `AddConfig` instance to use as configuration
         options.
+        :return: source_id, a md5-hash of the source, in hexadecimal representation.
         """
         if config is None:
             config = AddConfig()
 
+        try:
+            DataType(source)
+            logging.warning(
+                f"""Starting from version v0.0.39, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`"""  # noqa #E501
+            )
+            logging.warning(
+                "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code."  # noqa #E501
+            )
+            source, data_type = data_type, source
+        except ValueError:
+            pass
+
+        if data_type:
+            try:
+                data_type = DataType(data_type)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid data_type: '{data_type}'.",
+                    f"Please use one of the following: {[data_type.value for data_type in DataType]}",
+                ) from None
+        if not data_type:
+            data_type = detect_datatype(source)
+
+        # `source_id` is the hash of the source argument
+        hash_object = hashlib.md5(str(source).encode("utf-8"))
+        source_id = hash_object.hexdigest()
+
         data_formatter = DataFormatter(data_type, config)
-        self.user_asks.append([data_type, url, metadata])
+        self.user_asks.append([source, data_type.value, metadata])
         documents, _metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, url, metadata
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_id
         )
-        if data_type in ("docs_site",):
+        if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
 
         # Send anonymous telemetry
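
The `DataType(source)` probe at the top of `add` is the backward-compatibility shim: if the first argument parses as a valid data type string, the caller used the legacy `(data_type, url)` order, so the arguments are swapped after a warning. Illustrative (placeholder video id):

# legacy order still works:
app.add("youtube_video", "https://www.youtube.com/watch?v=VIDEO_ID")
# DataType("youtube_video") parses, so a warning is logged and the call
# is treated as the new order:
app.add("https://www.youtube.com/watch?v=VIDEO_ID", "youtube_video")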
@@ -75,41 +113,35 @@ class EmbedChain:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
             word_count = sum([len(document.split(" ")) for document in documents])
 
-            extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
+            extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
             thread_telemetry.start()
 
-    def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
+        return source_id
+
+    def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
         """
-        Adds the data you supply to the vector db.
+        Warning:
+            This method is deprecated and will be removed in future versions. Use `add` instead.
+
+        Adds the data from the given URL to the vector db.
+        Loads the data, chunks it, create embedding for each chunk
+        and then stores the embedding to vector database.
 
-        :param data_type: The type of the data to add.
-        :param content: The local data. Refer to the `README` for formatting.
+        :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
+        :param data_type: Optional. Automatically detected, but can be forced with this argument.
+        The type of the data to add.
         :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as
-        configuration options.
+        :param config: Optional. The `AddConfig` instance to use as configuration
+        options.
+        :return: md5-hash of the source, in hexadecimal representation.
         """
         if config is None:
             config = AddConfig()
 
-        data_formatter = DataFormatter(data_type, config)
-        self.user_asks.append([data_type, content])
-        documents, _metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, content, metadata
-        )
+        logging.warning(
+            "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
+        )
+        return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
 
-        # Send anonymous telemetry
-        if self.config.collect_metrics:
-            # it's quicker to check the variable twice than to count words when they won't be submitted.
-            word_count = sum([len(document.split(" ")) for document in documents])
-
-            extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
-            thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add_local", extra_metadata))
-            thread_telemetry.start()
-
-    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
+    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
 
@@ -118,12 +150,16 @@ class EmbedChain:
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         embeddings_data = chunker.create_chunks(loader, src)
+
+        # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
+
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
         # where={"url": src}
@@ -144,22 +180,31 @@ class EmbedChain:
             ids = list(data_dict.keys())
             documents, metadatas = zip(*data_dict.values())
 
-        # Add app id in metadatas so that they can be queried on later
-        if self.config.id is not None:
-            metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
+        # Loop though all metadatas and add extras.
+        new_metadatas = []
+        for m in metadatas:
+            # Add app id in metadatas so that they can be queried on later
+            if self.config.id:
+                m["app_id"] = self.config.id
 
-        # FIXME: Fix the error handling logic when metadatas or metadata is None
-        metadatas = metadatas if metadatas else []
-        metadata = metadata if metadata else {}
+            # Add hashed source
+            m["hash"] = source_id
+
+            # Note: Metadata is the function argument
+            if metadata:
+                # Spread whatever is in metadata into the new object.
+                m.update(metadata)
+
+            new_metadatas.append(m)
+        metadatas = new_metadatas
 
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.count()
 
-        # Add metadata to each document
-        metadatas_with_metadata = [{**meta, **metadata} for meta in metadatas]
-
-        self.db.add(documents=documents, metadatas=metadatas_with_metadata, ids=ids)
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
         count_new_chunks = self.count() - chunks_before_addition
-        print((f"Successfully saved {src}. New chunks count: {count_new_chunks}"))
-        return list(documents), metadatas_with_metadata, ids, count_new_chunks
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        return list(documents), metadatas, ids, count_new_chunks
 
     def _format_result(self, results):
         return [
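
The rewritten loop builds one merged metadata object per chunk instead of layering list comprehensions, and the new `hash` field makes every chunk traceable to its source. Assuming an app id and caller-supplied metadata are present, each stored entry ends up shaped roughly like this (field values invented for illustration):

{
    "url": "https://example.com/page",         # set by the chunker
    "data_type": "web_page",                   # set by the chunker via DataType.value
    "app_id": "my-app",                        # added when self.config.id is set
    "hash": "<md5 hexdigest of the source>",   # the new source_id
    "author": "jane",                          # spread in from the `metadata` argument
}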
embedchain/models/data_type.py (new file, 13 lines)

@@ -0,0 +1,13 @@
+from enum import Enum
+
+
+class DataType(Enum):
+    YOUTUBE_VIDEO = "youtube_video"
+    PDF_FILE = "pdf_file"
+    WEB_PAGE = "web_page"
+    SITEMAP = "sitemap"
+    DOCX = "docx"
+    DOCS_SITE = "docs_site"
+    TEXT = "text"
+    QNA_PAIR = "qna_pair"
+    NOTION = "notion"
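
Because each member's value is the old string identifier, the enum round-trips cleanly against existing string-based call sites:

from embedchain.models.data_type import DataType

assert DataType("youtube_video") is DataType.YOUTUBE_VIDEO  # string -> enum
assert DataType.YOUTUBE_VIDEO.value == "youtube_video"      # enum -> string

try:
    DataType("spreadsheet")  # unknown strings raise ValueError...
except ValueError:
    pass  # ...which `add` turns into its "Invalid data_type" error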
embedchain/utils.py

@@ -1,6 +1,10 @@
+import logging
 import os
 import re
 import string
+from typing import Any
+
+from embedchain.models.data_type import DataType
 
 
 def clean_string(text):
@@ -89,3 +93,113 @@ def use_pysqlite3():
             "Error:",
             e,
         )
+    __import__("pysqlite3")
+    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+    # Let the user know what happened.
+    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+    print(
+        f"{current_time} [embedchain] [INFO]",
+        "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
+        f"Your original version was {sqlite3.sqlite_version}.",
+    )
+
+
+def format_source(source: str, limit: int = 20) -> str:
+    """
+    Format a string to only take the first x and last x letters.
+    This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
+    If the string is too short, it is not sliced.
+    """
+    if len(source) > 2 * limit:
+        return source[:limit] + "..." + source[-limit:]
+    return source
+
+
+def detect_datatype(source: Any) -> DataType:
+    """
+    Automatically detect the datatype of the given source.
+
+    :param source: the source to base the detection on
+    :return: data_type string
+    """
+    from urllib.parse import urlparse
+
+    try:
+        if not isinstance(source, str):
+            raise ValueError("Source is not a string and thus cannot be a URL.")
+        url = urlparse(source)
+        # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
+        if not all([url.scheme, url.netloc]) and url.scheme != "file":
+            raise ValueError("Not a valid URL.")
+    except ValueError:
+        url = False
+
+    formatted_source = format_source(str(source), 30)
+
+    if url:
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+
+        if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
+            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
+            return DataType.YOUTUBE_VIDEO
+
+        if url.netloc in {"notion.so", "notion.site"}:
+            logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
+            return DataType.NOTION
+
+        if url.path.endswith(".pdf"):
+            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
+            return DataType.PDF_FILE
+
+        if url.path.endswith(".xml"):
+            logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
+            return DataType.SITEMAP
+
+        if url.path.endswith(".docx"):
+            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
+            return DataType.DOCX
+
+        if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
+            # `docs_site` detection via path is not accepted for local filesystem URIs,
+            # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
+            logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
+            return DataType.DOCS_SITE
+
+        # If none of the above conditions are met, it's a general web page
+        logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
+        return DataType.WEB_PAGE
+
+    elif not isinstance(source, str):
+        # For datatypes where source is not a string.
+
+        if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
+            logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
+            return DataType.QNA_PAIR
+
+        # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
+        # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
+        raise TypeError(
+            "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`."  # noqa: E501
+        )
+
+    elif os.path.isfile(source):
+        # For datatypes that support conventional file references.
+        # Note: checking for string is not necessary anymore.
+
+        if source.endswith(".docx"):
+            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
+            return DataType.DOCX
+
+        # If the source is a valid file, that's not detectable as a type, an error is raised.
+        # It does not fallback to text.
+        raise ValueError(
+            "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
+        )
+
+    else:
+        # Source is not a URL.
+
+        # Use text as final fallback.
+        logging.debug(f"Source of `{formatted_source}` detected as `text`.")
+        return DataType.TEXT
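
Tracing the branches above, detection behaves roughly as follows (illustrative inputs; the YouTube branch depends on langchain's ALLOWED_NETLOCK set and is omitted here):

from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype

assert detect_datatype("https://example.com/paper.pdf") is DataType.PDF_FILE
assert detect_datatype("https://example.com/sitemap.xml") is DataType.SITEMAP
assert detect_datatype("https://docs.example.com/intro") is DataType.DOCS_SITE
assert detect_datatype("https://example.com/blog") is DataType.WEB_PAGE
assert detect_datatype(("What is X?", "X is ...")) is DataType.QNA_PAIR  # (question, answer) tuple
assert detect_datatype("plain text, not a URL or file") is DataType.TEXT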