489 lines
18 KiB
Python
489 lines
18 KiB
Python
import itertools
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import string
|
|
from typing import Any
|
|
|
|
from schema import Optional, Or, Schema
|
|
from tqdm import tqdm
|
|
|
|
from embedchain.models.data_type import DataType
|
|
|
|
|
|
def parse_content(content, type):
    """
    Strip boilerplate elements from a markup document and return its cleaned text.

    The document is parsed with BeautifulSoup. Elements are then removed by tag
    name, by id and by CSS class, and the remaining text is passed through
    `clean_string`. When the unfiltered text is non-empty, the size reduction
    is logged at INFO level.

    :param content: markup to parse; handed to BeautifulSoup unchanged
    :param type: name of the BeautifulSoup parser to use
    :return: cleaned text of the document
    :raises ValueError: if `type` is not one of the supported parser names
    """
    supported_parsers = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
    if type not in supported_parsers:
        raise ValueError(f"Parser type {type} not implemented. Please choose one of {supported_parsers}")

    # Imported only after the parser name has been validated.
    from bs4 import BeautifulSoup

    unwanted_tags = [
        "nav",
        "aside",
        "form",
        "header",
        "noscript",
        "svg",
        "canvas",
        "footer",
        "script",
        "style",
    ]
    unwanted_ids = ["sidebar", "main-navigation", "menu-main-menu"]
    unwanted_classes = [
        "elementor-location-header",
        "navbar-header",
        "nav",
        "header-sidebar-wrapper",
        "blog-sidebar-wrapper",
        "related-posts",
    ]

    soup = BeautifulSoup(content, type)
    # Measured before any element is removed, for the log line below.
    size_before = len(str(soup.get_text()))

    for element in soup(unwanted_tags):
        element.decompose()

    for element_id in unwanted_ids:
        for element in soup.find_all(id=element_id):
            element.decompose()

    for css_class in unwanted_classes:
        for element in soup.find_all(class_=css_class):
            element.decompose()

    cleaned = clean_string(soup.get_text())
    size_after = len(cleaned)

    # An empty original text would make the percentage a division by zero.
    if size_before != 0:
        shrunk_chars = size_before - size_after
        shrunk_percent = round((1 - (size_after / size_before)) * 100, 2)
        logging.info(
            f"Cleaned page size: {size_after} characters, down from {size_before} (shrunk: {shrunk_chars} chars, {shrunk_percent}%)"  # noqa:E501
        )

    return cleaned
|
|
|
|
|
|
def clean_string(text):
    """
    Normalise a piece of text through a fixed sequence of cleaning steps.

    Steps, in order: newlines become spaces, the text is stripped and runs of
    whitespace collapse to one space, backslashes are removed, hash characters
    become spaces, and runs of one repeated punctuation character collapse to
    a single occurrence.

    Args:
        text (str): The text to be cleaned.

    Returns:
        str: The cleaned text.
    """
    # Flatten line breaks, trim the ends, then squeeze whitespace runs.
    flattened = text.replace("\n", " ").strip()
    squeezed = re.sub(r"\s+", " ", flattened)

    # Drop backslashes and blank out hashes. The hash replacement runs after
    # the whitespace squeeze, so it can leave adjacent spaces behind.
    without_backslashes = squeezed.replace("\\", "")
    without_hashes = without_backslashes.replace("#", " ")

    # A character that is neither a word character nor whitespace, followed by
    # repeats of itself, is reduced to one occurrence:
    # "!!! hello !!!" -> "! hello !".
    return re.sub(r"([^\w\s])\1*", r"\1", without_hashes)
|
|
|
|
|
|
def is_readable(s):
    """
    Heuristic check of whether a string is "readable", based on its share of printable characters.

    A character counts as printable when it is contained in `string.printable`.
    An empty input is logged with a warning and treated as unreadable.

    :param s: string
    :return: True if more than 95% of the characters are printable.
    """
    length = len(s)
    if length == 0:
        logging.warning("Empty string processed as unreadable")
        return False

    printable_count = sum(1 for char in s if char in string.printable)
    return printable_count / length > 0.95  # strictly more than 95% printable
|
|
|
|
|
|
def use_pysqlite3():
    """
    Swap std-lib sqlite3 with pysqlite3.

    Does nothing unless the platform is Linux and the std-lib SQLite version
    is older than 3.35.0. Otherwise `pysqlite3-binary` is installed with pip,
    imported, and registered in `sys.modules` under the name `sqlite3`. The
    outcome is printed; any failure is caught and reported, never raised.
    """
    import platform
    import sqlite3

    if platform.system() != "Linux" or sqlite3.sqlite_version_info >= (3, 35, 0):
        return

    # According to the Chroma team, this patch only works on Linux
    import datetime
    import subprocess
    import sys

    def log_prefix(level):
        # Timestamp with millisecond precision, followed by the package tag and level.
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
        return f"{stamp} [embedchain] [{level}]"

    try:
        install_command = [
            sys.executable,
            "-m",
            "pip",
            "install",
            "pysqlite3-binary",
            "--quiet",
            "--disable-pip-version-check",
        ]
        subprocess.check_call(install_command)

        __import__("pysqlite3")
        # Later `import sqlite3` statements now resolve to pysqlite3.
        sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

        # Let the user know what happened. The local name `sqlite3` still
        # refers to the std-lib module, so this reports the original version.
        print(
            log_prefix("INFO"),
            "Swapped std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
            f"Your original version was {sqlite3.sqlite_version}.",
        )
    except Exception as e:
        # Escape all exceptions
        print(
            log_prefix("ERROR"),
            "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
            "Error:",
            e,
        )
|
|
|
|
|
|
def format_source(source: str, limit: int = 20) -> str:
    """
    Format a string to only take the first x and last x letters.

    This makes it easier to display a URL, keeping familiarity while ensuring a consistent length.
    If the string is too short (at most `2 * limit` characters), it is not sliced.

    :param source: the string to shorten
    :param limit: number of characters kept at each end; negative values are treated as 0
    :return: `source` unchanged, or its first and last `limit` characters joined by "..."
    """
    limit = max(limit, 0)
    if len(source) <= 2 * limit:
        return source
    # The tail is sliced with an explicit start index: `source[-limit:]` would
    # return the whole string when `limit` is 0, because `-0` equals `0`.
    tail_start = len(source) - limit
    return source[:limit] + "..." + source[tail_start:]
|
|
|
|
|
|
def detect_datatype(source: Any) -> DataType:
    """
    Automatically detect the datatype of the given source.

    Detection proceeds in this order:

    1. URLs (scheme and netloc present, or a `file` scheme): matched by host,
       then by path suffix and other URL heuristics, falling back to `web_page`.
    2. Non-string sources: only a 2-tuple of strings is accepted (`qna_pair`).
    3. Paths of existing local files: matched by filename suffix, then by a
       readability check of the file content.
    4. Any other string: `json` if it parses as JSON, otherwise `text`.

    :param source: the source to base the detection on
    :return: the detected `DataType` member
    :raises TypeError: if `source` is neither a string nor a supported non-string
        type, or if a `.yaml` URL does not yield a valid OpenAPI document
    :raises ValueError: if `source` is an existing local file whose type cannot
        be detected, or a local `.yaml` file that is not an OpenAPI document
    """
    from urllib.parse import urlparse

    import requests
    import yaml

    def is_openapi_yaml(yaml_content):
        # currently the following two fields are required in openapi spec yaml config.
        # The document must be a mapping: `yaml.safe_load` returns None for an empty
        # document and may return a scalar or a list, on which the key checks below
        # would either fail or silently do substring/element matching.
        return isinstance(yaml_content, dict) and "openapi" in yaml_content and "info" in yaml_content

    def is_google_drive_folder(url):
        # checks if url is a Google Drive folder url against a regex
        # NOTE(review): the `u/<number>/` segment is mandatory in this pattern, so folder
        # URLs without it are not matched - confirm this is intended.
        regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
        return re.match(regex, url)

    try:
        if not isinstance(source, str):
            raise ValueError("Source is not a string and thus cannot be a URL.")
        url = urlparse(source)
        # Check if both scheme and netloc are present. Local file system URIs are acceptable too.
        if not all([url.scheme, url.netloc]) and url.scheme != "file":
            raise ValueError("Not a valid URL.")
    except ValueError:
        url = False

    formatted_source = format_source(str(source), 30)

    if url:
        YOUTUBE_ALLOWED_NETLOCKS = {
            "www.youtube.com",
            "m.youtube.com",
            "youtu.be",
            "youtube.com",
            "vid.plus",
            "www.youtube-nocookie.com",
        }

        if url.netloc in YOUTUBE_ALLOWED_NETLOCKS:
            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
            return DataType.YOUTUBE_VIDEO

        if url.netloc in {"notion.so", "notion.site"}:
            logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
            return DataType.NOTION

        if url.path.endswith(".pdf"):
            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
            return DataType.PDF_FILE

        if url.path.endswith(".xml"):
            logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
            return DataType.SITEMAP

        if url.path.endswith(".csv"):
            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
            return DataType.CSV

        if url.path.endswith((".mdx", ".md")):
            logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
            return DataType.MDX

        if url.path.endswith(".docx"):
            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
            return DataType.DOCX

        if url.path.endswith(".yaml"):
            try:
                response = requests.get(source)
                response.raise_for_status()
                try:
                    yaml_content = yaml.safe_load(response.text)
                except yaml.YAMLError as exc:
                    logging.error(f"Error parsing YAML: {exc}")
                    raise TypeError(f"Not a valid data type. Error loading YAML: {exc}") from exc

                if is_openapi_yaml(yaml_content):
                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                    return DataType.OPENAPI
                else:
                    logging.error(
                        f"Source of `{formatted_source}` does not contain all the required "
                        "fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                    )
                    raise TypeError(
                        "Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', "
                        "make sure you have all the required fields in YAML config data"
                    )
            except requests.exceptions.RequestException as e:
                # A failed download is only logged; detection continues with the checks below.
                logging.error(f"Error fetching URL {formatted_source}: {e}")

        if url.path.endswith(".json"):
            logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
            return DataType.JSON

        if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
            # `docs_site` detection via path is not accepted for local filesystem URIs,
            # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
            logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
            return DataType.DOCS_SITE

        if "github.com" in url.netloc:
            logging.debug(f"Source of `{formatted_source}` detected as `github`.")
            return DataType.GITHUB

        if is_google_drive_folder(url.netloc + url.path):
            logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
            return DataType.GOOGLE_DRIVE_FOLDER

        # If none of the above conditions are met, it's a general web page
        logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
        return DataType.WEB_PAGE

    elif not isinstance(source, str):
        # For datatypes where source is not a string.

        if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
            logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
            return DataType.QNA_PAIR

        # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
        # We could stringify it, but it is better to raise an error and let the user decide how they want to do that.
        raise TypeError(
            "Source is not a string and a valid non-string type could not be detected. If you want to embed it, please stringify it, for instance by using `str(source)` or `(', ').join(source)`."  # noqa: E501
        )

    elif os.path.isfile(source):
        # For datatypes that support conventional file references.
        # Note: checking for string is not necessary anymore.

        if source.endswith(".docx"):
            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
            return DataType.DOCX

        if source.endswith(".csv"):
            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
            return DataType.CSV

        if source.endswith(".xml"):
            logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
            return DataType.XML

        if source.endswith((".mdx", ".md")):
            logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
            return DataType.MDX

        if source.endswith(".txt"):
            logging.debug(f"Source of `{formatted_source}` detected as `text`.")
            return DataType.TEXT_FILE

        if source.endswith(".pdf"):
            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
            return DataType.PDF_FILE

        if source.endswith(".yaml"):
            with open(source, "r") as file:
                yaml_content = yaml.safe_load(file)
                if is_openapi_yaml(yaml_content):
                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                    return DataType.OPENAPI
                else:
                    logging.error(
                        f"Source of `{formatted_source}` does not contain all the required "
                        "fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                    )
                    raise ValueError(
                        "Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', "
                        "make sure to add all the required params"
                    )

        if source.endswith(".json"):
            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
            return DataType.JSON

        # Last resort for files with an unknown suffix: treat the file as text if its
        # content is mostly printable. The file is read inside a `with` block so the
        # handle is always closed. Content that cannot be decoded as text is treated
        # as unreadable, so the explanatory error below is raised instead of a bare
        # decoding error (which is itself a ValueError subclass).
        try:
            with open(source) as file:
                file_text = file.read()
        except UnicodeDecodeError:
            file_text = None

        if file_text is not None and is_readable(file_text):
            logging.debug(f"Source of `{formatted_source}` detected as `text_file`.")
            return DataType.TEXT_FILE

        # If the source is a valid file, that's not detectable as a type, an error is raised.
        # It does not fall back to text.
        raise ValueError(
            "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
        )

    else:
        # Source is not a URL.

        # TODO: check if source is gmail query

        # check if the source is valid json string
        if is_valid_json_string(source):
            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
            return DataType.JSON

        # Use text as final fallback.
        logging.debug(f"Source of `{formatted_source}` detected as `text`.")
        return DataType.TEXT
|
|
|
|
|
|
# check if the source is valid json string
|
|
def is_valid_json_string(source: str):
    """
    Report whether `source` can be parsed as JSON.

    :param source: the string to test
    :return: True if `json.loads` accepts the string, False if it raises `json.JSONDecodeError`
    """
    try:
        json.loads(source)
    except json.JSONDecodeError:
        return False
    return True
|
|
|
|
|
|
def validate_config(config_data):
    """
    Validate a configuration mapping against the expected schema.

    Every section and every key is optional, but keys that are present must
    hold a value of the declared type or one of the listed choices.

    Numeric settings that are conceptually real numbers (`temperature`,
    `top_p`, `max_distance`, `similarity_threshold`) accept both `float` and
    `int`, so a value written as `1` is not rejected for lacking a decimal point.

    :param config_data: the configuration to validate
    :return: the validated data, as returned by `Schema.validate`
    :raises schema.SchemaError: if the data does not conform to the schema
    """
    schema = Schema(
        {
            Optional("app"): {
                Optional("config"): {
                    Optional("id"): str,
                    Optional("name"): str,
                    Optional("log_level"): Or("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
                    Optional("collect_metrics"): bool,
                    Optional("collection_name"): str,
                }
            },
            Optional("llm"): {
                Optional("provider"): Or(
                    "openai",
                    "azure_openai",
                    "anthropic",
                    "huggingface",
                    "cohere",
                    "together",
                    "gpt4all",
                    "ollama",
                    "jina",
                    "llama2",
                    "vertexai",
                    "google",
                ),
                Optional("config"): {
                    Optional("model"): str,
                    Optional("number_documents"): int,
                    Optional("temperature"): Or(float, int),
                    Optional("max_tokens"): int,
                    Optional("top_p"): Or(float, int),
                    Optional("stream"): bool,
                    Optional("template"): str,
                    Optional("prompt"): str,
                    Optional("system_prompt"): str,
                    Optional("deployment_name"): str,
                    Optional("where"): dict,
                    Optional("query_type"): str,
                    Optional("api_key"): str,
                    Optional("endpoint"): str,
                },
            },
            Optional("vectordb"): {
                Optional("provider"): Or(
                    "chroma", "elasticsearch", "opensearch", "pinecone", "qdrant", "weaviate", "zilliz"
                ),
                Optional("config"): object,  # TODO: add particular config schema for each provider
            },
            Optional("embedder"): {
                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
                Optional("config"): {
                    Optional("model"): Optional(str),
                    Optional("deployment_name"): Optional(str),
                    Optional("api_key"): str,
                    Optional("title"): str,
                    Optional("task_type"): str,
                    Optional("vector_dimension"): int,
                },
            },
            Optional("embedding_model"): {
                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
                Optional("config"): {
                    Optional("model"): str,
                    Optional("deployment_name"): str,
                    Optional("api_key"): str,
                    Optional("title"): str,
                    Optional("task_type"): str,
                    Optional("vector_dimension"): int,
                },
            },
            Optional("chunker"): {
                Optional("chunk_size"): int,
                Optional("chunk_overlap"): int,
                Optional("length_function"): str,
                Optional("min_chunk_size"): int,
            },
            Optional("cache"): {
                Optional("similarity_evaluation"): {
                    Optional("strategy"): Or("distance", "exact"),
                    Optional("max_distance"): Or(float, int),
                    Optional("positive"): bool,
                },
                Optional("config"): {
                    Optional("similarity_threshold"): Or(float, int),
                    Optional("auto_flush"): int,
                },
            },
        }
    )

    return schema.validate(config_data)
|
|
|
|
|
|
def chunks(iterable, batch_size=100, desc="Processing chunks"):
    """
    A helper function to break an iterable into chunks of size batch_size.

    Progress is reported through a tqdm progress bar that advances by the
    number of items in each chunk after the chunk has been consumed.

    :param iterable: the items to split; may be a sized collection or any other iterable
    :param batch_size: maximum number of items per chunk
    :param desc: description shown on the progress bar
    :return: a generator yielding tuples of at most `batch_size` items; the last one may be shorter
    """
    try:
        total_size = len(iterable)
    except TypeError:
        # Generators and other unsized iterables have no length; with `total=None`
        # the progress bar shows a running count instead of a percentage.
        total_size = None

    it = iter(iterable)

    with tqdm(total=total_size, desc=desc, unit="batch") as pbar:
        chunk = tuple(itertools.islice(it, batch_size))
        while chunk:
            yield chunk
            pbar.update(len(chunk))
            chunk = tuple(itertools.islice(it, batch_size))
|