feat: add method - detect format / data_type (#380)

This commit is contained in:
cachho
2023-08-16 22:18:24 +02:00
committed by GitHub
parent f92e890aa1
commit 4c8876f032
18 changed files with 472 additions and 121 deletions

View File

@@ -1,5 +1,7 @@
import hashlib
from embedchain.models.data_type import DataType
class BaseChunker:
def __init__(self, text_splitter):
@@ -26,7 +28,7 @@ class BaseChunker:
meta_data = data["meta_data"]
# add data type to meta data to allow query using data type
meta_data["data_type"] = self.data_type
meta_data["data_type"] = self.data_type.value
url = meta_data["url"]
chunks = self.get_chunks(content)
@@ -52,8 +54,10 @@ class BaseChunker:
"""
return self.text_splitter.split_text(content)
def set_data_type(self, data_type):
def set_data_type(self, data_type: DataType):
"""
set the data type of chunker
"""
self.data_type = data_type
# TODO: This should be done during initialization. This means it has to be done in the child classes.

View File

@@ -15,6 +15,7 @@ from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.sitemap import SitemapLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.models.data_type import DataType
class DataFormatter:
@@ -24,11 +25,11 @@ class DataFormatter:
.add or .add_local method call
"""
def __init__(self, data_type: str, config: AddConfig):
def __init__(self, data_type: DataType, config: AddConfig):
self.loader = self._get_loader(data_type, config.loader)
self.chunker = self._get_chunker(data_type, config.chunker)
def _get_loader(self, data_type, config):
def _get_loader(self, data_type: DataType, config):
"""
Returns the appropriate data loader for the given data type.
@@ -37,22 +38,22 @@ class DataFormatter:
:raises ValueError: If an unsupported data type is provided.
"""
loaders = {
"youtube_video": YoutubeVideoLoader,
"pdf_file": PdfFileLoader,
"web_page": WebPageLoader,
"qna_pair": LocalQnaPairLoader,
"text": LocalTextLoader,
"docx": DocxFileLoader,
"sitemap": SitemapLoader,
"docs_site": DocsSiteLoader,
DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
DataType.PDF_FILE: PdfFileLoader,
DataType.WEB_PAGE: WebPageLoader,
DataType.QNA_PAIR: LocalQnaPairLoader,
DataType.TEXT: LocalTextLoader,
DataType.DOCX: DocxFileLoader,
DataType.SITEMAP: SitemapLoader,
DataType.DOCS_SITE: DocsSiteLoader,
}
lazy_loaders = ("notion",)
lazy_loaders = {DataType.NOTION}
if data_type in loaders:
loader_class = loaders[data_type]
loader = loader_class()
return loader
elif data_type in lazy_loaders:
if data_type == "notion":
if data_type == DataType.NOTION:
from embedchain.loaders.notion import NotionLoader
return NotionLoader()
@@ -61,7 +62,7 @@ class DataFormatter:
else:
raise ValueError(f"Unsupported data type: {data_type}")
def _get_chunker(self, data_type, config):
def _get_chunker(self, data_type: DataType, config):
"""
Returns the appropriate chunker for the given data type.
@@ -70,15 +71,15 @@ class DataFormatter:
:raises ValueError: If an unsupported data type is provided.
"""
chunker_classes = {
"youtube_video": YoutubeVideoChunker,
"pdf_file": PdfFileChunker,
"web_page": WebPageChunker,
"qna_pair": QnaPairChunker,
"text": TextChunker,
"docx": DocxFileChunker,
"sitemap": WebPageChunker,
"docs_site": DocsSiteChunker,
"notion": NotionChunker,
DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
DataType.PDF_FILE: PdfFileChunker,
DataType.WEB_PAGE: WebPageChunker,
DataType.QNA_PAIR: QnaPairChunker,
DataType.TEXT: TextChunker,
DataType.DOCX: DocxFileChunker,
DataType.WEB_PAGE: WebPageChunker,
DataType.DOCS_SITE: DocsSiteChunker,
DataType.NOTION: NotionChunker,
}
if data_type in chunker_classes:
chunker_class = chunker_classes[data_type]

View File

@@ -1,9 +1,10 @@
import hashlib
import importlib.metadata
import logging
import os
import threading
import uuid
from typing import Optional
from typing import Dict, Optional
import requests
from dotenv import load_dotenv
@@ -17,6 +18,8 @@ from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
from embedchain.data_formatter import DataFormatter
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype
load_dotenv()
@@ -47,27 +50,62 @@ class EmbedChain:
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
thread_telemetry.start()
def add(self, data_type, url, metadata=None, config: AddConfig = None):
def add(
self,
source,
data_type: Optional[DataType] = None,
metadata: Optional[Dict] = None,
config: Optional[AddConfig] = None,
):
"""
Adds the data from the given URL to the vector db.
Loads the data, chunks it, create embedding for each chunk
and then stores the embedding to vector database.
:param data_type: The type of the data to add.
:param url: The URL where the data is located.
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
:param data_type: Optional. Automatically detected, but can be forced with this argument.
The type of the data to add.
:param metadata: Optional. Metadata associated with the data source.
:param config: Optional. The `AddConfig` instance to use as configuration
options.
:return: source_id, a md5-hash of the source, in hexadecimal representation.
"""
if config is None:
config = AddConfig()
try:
DataType(source)
logging.warning(
f"""Starting from version v0.0.39, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`""" # noqa #E501
)
logging.warning(
"Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code." # noqa #E501
)
source, data_type = data_type, source
except ValueError:
pass
if data_type:
try:
data_type = DataType(data_type)
except ValueError:
raise ValueError(
f"Invalid data_type: '{data_type}'.",
f"Please use one of the following: {[data_type.value for data_type in DataType]}",
) from None
if not data_type:
data_type = detect_datatype(source)
# `source_id` is the hash of the source argument
hash_object = hashlib.md5(str(source).encode("utf-8"))
source_id = hash_object.hexdigest()
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, url, metadata])
self.user_asks.append([source, data_type.value, metadata])
documents, _metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, url, metadata
data_formatter.loader, data_formatter.chunker, source, metadata, source_id
)
if data_type in ("docs_site",):
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
# Send anonymous telemetry
@@ -75,41 +113,35 @@ class EmbedChain:
# it's quicker to check the variable twice than to count words when they won't be submitted.
word_count = sum([len(document.split(" ")) for document in documents])
extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
thread_telemetry.start()
def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
return source_id
def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
"""
Adds the data you supply to the vector db.
Warning:
This method is deprecated and will be removed in future versions. Use `add` instead.
Adds the data from the given URL to the vector db.
Loads the data, chunks it, create embedding for each chunk
and then stores the embedding to vector database.
:param data_type: The type of the data to add.
:param content: The local data. Refer to the `README` for formatting.
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
:param data_type: Optional. Automatically detected, but can be forced with this argument.
The type of the data to add.
:param metadata: Optional. Metadata associated with the data source.
:param config: Optional. The `AddConfig` instance to use as
configuration options.
:param config: Optional. The `AddConfig` instance to use as configuration
options.
:return: md5-hash of the source, in hexadecimal representation.
"""
if config is None:
config = AddConfig()
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, content])
documents, _metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, content, metadata
logging.warning(
"The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501
)
return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
# Send anonymous telemetry
if self.config.collect_metrics:
# it's quicker to check the variable twice than to count words when they won't be submitted.
word_count = sum([len(document.split(" ")) for document in documents])
extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add_local", extra_metadata))
thread_telemetry.start()
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
"""
Loads the data from the given URL, chunks it, and adds it to database.
@@ -118,12 +150,16 @@ class EmbedChain:
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
"""
embeddings_data = chunker.create_chunks(loader, src)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
# get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {}
# where={"url": src}
@@ -144,22 +180,31 @@ class EmbedChain:
ids = list(data_dict.keys())
documents, metadatas = zip(*data_dict.values())
# Add app id in metadatas so that they can be queried on later
if self.config.id is not None:
metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
# Loop though all metadatas and add extras.
new_metadatas = []
for m in metadatas:
# Add app id in metadatas so that they can be queried on later
if self.config.id:
m["app_id"] = self.config.id
# FIXME: Fix the error handling logic when metadatas or metadata is None
metadatas = metadatas if metadatas else []
metadata = metadata if metadata else {}
# Add hashed source
m["hash"] = source_id
# Note: Metadata is the function argument
if metadata:
# Spread whatever is in metadata into the new object.
m.update(metadata)
new_metadatas.append(m)
metadatas = new_metadatas
# Count before, to calculate a delta in the end.
chunks_before_addition = self.count()
# Add metadata to each document
metadatas_with_metadata = [{**meta, **metadata} for meta in metadatas]
self.db.add(documents=documents, metadatas=metadatas_with_metadata, ids=ids)
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition
print((f"Successfully saved {src}. New chunks count: {count_new_chunks}"))
return list(documents), metadatas_with_metadata, ids, count_new_chunks
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):
return [

View File

@@ -0,0 +1,13 @@
from enum import Enum
class DataType(Enum):
YOUTUBE_VIDEO = "youtube_video"
PDF_FILE = "pdf_file"
WEB_PAGE = "web_page"
SITEMAP = "sitemap"
DOCX = "docx"
DOCS_SITE = "docs_site"
TEXT = "text"
QNA_PAIR = "qna_pair"
NOTION = "notion"

View File

@@ -1,6 +1,10 @@
import logging
import os
import re
import string
from typing import Any
from embedchain.models.data_type import DataType
def clean_string(text):
@@ -89,3 +93,113 @@ def use_pysqlite3():
"Error:",
e,
)
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
# Let the user know what happened.
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
print(
f"{current_time} [embedchain] [INFO]",
"Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
f"Your original version was {sqlite3.sqlite_version}.",
)
def format_source(source: str, limit: int = 20) -> str:
"""
Format a string to only take the first x and last x letters.
This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
If the string is too short, it is not sliced.
"""
if len(source) > 2 * limit:
return source[:limit] + "..." + source[-limit:]
return source
def detect_datatype(source: Any) -> DataType:
"""
Automatically detect the datatype of the given source.
:param source: the source to base the detection on
:return: data_type string
"""
from urllib.parse import urlparse
try:
if not isinstance(source, str):
raise ValueError("Source is not a string and thus cannot be a URL.")
url = urlparse(source)
# Check if both scheme and netloc are present. Local file system URIs are acceptable too.
if not all([url.scheme, url.netloc]) and url.scheme != "file":
raise ValueError("Not a valid URL.")
except ValueError:
url = False
formatted_source = format_source(str(source), 30)
if url:
from langchain.document_loaders.youtube import \
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
return DataType.YOUTUBE_VIDEO
if url.netloc in {"notion.so", "notion.site"}:
logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
return DataType.NOTION
if url.path.endswith(".pdf"):
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
return DataType.PDF_FILE
if url.path.endswith(".xml"):
logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
return DataType.SITEMAP
if url.path.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
# `docs_site` detection via path is not accepted for local filesystem URIs,
# because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
return DataType.DOCS_SITE
# If none of the above conditions are met, it's a general web page
logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
return DataType.WEB_PAGE
elif not isinstance(source, str):
# For datatypes where source is not a string.
if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
return DataType.QNA_PAIR
# Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
# We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
raise TypeError(
"Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`." # noqa: E501
)
elif os.path.isfile(source):
# For datatypes that support conventional file references.
# Note: checking for string is not necessary anymore.
if source.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
# If the source is a valid file, that's not detectable as a type, an error is raised.
# It does not fallback to text.
raise ValueError(
"Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)." # noqa: E501
)
else:
# Source is not a URL.
# Use text as final fallback.
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
return DataType.TEXT