[feat]: Add openapi spec data loader (#818)
This commit is contained in:
18
embedchain/chunkers/openapi.py
Normal file
18
embedchain/chunkers/openapi.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
|
||||
|
||||
class OpenAPIChunker(BaseChunker):
    """Chunker for OpenAPI spec data, built on a recursive character text splitter."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """Create the chunker.

        Args:
            config: Chunking settings. When omitted, a config with
                ``chunk_size=1000``, ``chunk_overlap=0`` and
                ``length_function=len`` is used.
        """
        # Build the fallback config only when the caller supplied none.
        settings = (
            config
            if config is not None
            else ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        )
        splitter_options = {
            "chunk_size": settings.chunk_size,
            "chunk_overlap": settings.chunk_overlap,
            "length_function": settings.length_function,
        }
        super().__init__(RecursiveCharacterTextSplitter(**splitter_options))
|
||||
@@ -5,6 +5,7 @@ from embedchain.chunkers.images import ImagesChunker
|
||||
from embedchain.chunkers.json import JSONChunker
|
||||
from embedchain.chunkers.mdx import MdxChunker
|
||||
from embedchain.chunkers.notion import NotionChunker
|
||||
from embedchain.chunkers.openapi import OpenAPIChunker
|
||||
from embedchain.chunkers.pdf_file import PdfFileChunker
|
||||
from embedchain.chunkers.qna_pair import QnaPairChunker
|
||||
from embedchain.chunkers.sitemap import SitemapChunker
|
||||
@@ -26,6 +27,7 @@ from embedchain.loaders.json import JSONLoader
|
||||
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
||||
from embedchain.loaders.local_text import LocalTextLoader
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
from embedchain.loaders.openapi import OpenAPILoader
|
||||
from embedchain.loaders.pdf_file import PdfFileLoader
|
||||
from embedchain.loaders.sitemap import SitemapLoader
|
||||
from embedchain.loaders.unstructured_file import UnstructuredLoader
|
||||
@@ -81,6 +83,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.IMAGES: ImagesLoader,
|
||||
DataType.UNSTRUCTURED: UnstructuredLoader,
|
||||
DataType.JSON: JSONLoader,
|
||||
DataType.OPENAPI: OpenAPILoader,
|
||||
}
|
||||
lazy_loaders = {DataType.NOTION}
|
||||
if data_type in loaders:
|
||||
@@ -124,6 +127,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.XML: XmlChunker,
|
||||
DataType.UNSTRUCTURED: UnstructuredFileChunker,
|
||||
DataType.JSON: JSONChunker,
|
||||
DataType.OPENAPI: OpenAPIChunker,
|
||||
}
|
||||
if data_type in chunker_classes:
|
||||
chunker_class: type = chunker_classes[data_type]
|
||||
|
||||
42
embedchain/loaders/openapi.py
Normal file
42
embedchain/loaders/openapi.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import hashlib
|
||||
from io import StringIO
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
class OpenAPILoader(BaseLoader):
    """Loader that reads an OpenAPI YAML spec and emits one document per top-level key."""

    @staticmethod
    def _get_file_content(content):
        """Return a readable text stream for ``content``.

        Args:
            content: A local file path, a ``file://`` URL, or an ``http(s)://`` URL.

        Returns:
            A context-manager-capable text stream: a ``StringIO`` wrapping the
            response body for HTTP(S) sources, otherwise an open file object.

        Raises:
            ValueError: If ``content`` is a URL (scheme and netloc present)
                whose scheme is not ``file``, ``http`` or ``https``.
            requests.exceptions.HTTPError: If the HTTP response has an error status.
        """
        url = urlparse(content)
        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
            raise ValueError("Not a valid URL.")

        if url.scheme in ["http", "https"]:
            response = requests.get(content)
            response.raise_for_status()
            return StringIO(response.text)
        elif url.scheme == "file":
            return open(url.path)
        else:
            # No recognised scheme: treat the whole string as a local path.
            return open(content)

    @staticmethod
    def load_data(content):
        """Load an OpenAPI YAML file; each top-level key/value pair is a document.

        Args:
            content: Path or URL of the YAML file (see ``_get_file_content``).

        Returns:
            A dict with ``doc_id`` (SHA-256 hex digest of ``content`` plus the
            joined document strings) and ``data`` (a list of
            ``{"content", "meta_data"}`` dicts, where ``meta_data`` holds the
            source ``url`` and the 1-based ``row`` of the pair).

        Raises:
            ValueError: If the YAML document is empty or its top level is not
                a mapping.
        """
        with OpenAPILoader._get_file_content(content=content) as file:
            # SECURITY: the source may be an arbitrary remote URL, so use the
            # safe loader. The full ``yaml.Loader`` can construct arbitrary
            # Python objects from tagged nodes; OpenAPI specs only need plain
            # YAML types.
            yaml_data = yaml.safe_load(file)

        # An empty file parses to None and a scalar/list has no key/value
        # pairs; fail with a clear message instead of an AttributeError.
        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Expected a YAML mapping at the top level of `{content}`, "
                f"got {type(yaml_data).__name__}."
            )

        data_content = [f"{key}: {value}" for key, value in yaml_data.items()]
        data = [
            {"content": string_data, "meta_data": {"url": content, "row": row}}
            for row, string_data in enumerate(data_content, start=1)
        ]
        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": data}
|
||||
@@ -27,6 +27,7 @@ class IndirectDataType(Enum):
|
||||
IMAGES = "images"
|
||||
UNSTRUCTURED = "unstructured"
|
||||
JSON = "json"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
|
||||
class SpecialDataType(Enum):
|
||||
@@ -53,3 +54,4 @@ class DataType(Enum):
|
||||
IMAGES = IndirectDataType.IMAGES.value
|
||||
UNSTRUCTURED = IndirectDataType.UNSTRUCTURED.value
|
||||
JSON = IndirectDataType.JSON.value
|
||||
OPENAPI = IndirectDataType.OPENAPI.value
|
||||
|
||||
@@ -115,6 +115,13 @@ def detect_datatype(source: Any) -> DataType:
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
def is_openapi_yaml(yaml_content):
    """Return True if parsed YAML looks like an OpenAPI spec.

    Args:
        yaml_content: The object produced by parsing a YAML document. May be
            any YAML value, including ``None`` for an empty document.

    Returns:
        True only when ``yaml_content`` is a mapping that has both the
        ``openapi`` and ``info`` top-level keys; False otherwise.
    """
    # Guard against non-mapping documents: `in` on None raises TypeError, and
    # on a plain string it would do substring matching and give false positives.
    if not isinstance(yaml_content, dict):
        return False
    # currently the following two fields are required in openapi spec yaml config
    return "openapi" in yaml_content and "info" in yaml_content
|
||||
|
||||
try:
|
||||
if not isinstance(source, str):
|
||||
raise ValueError("Source is not a string and thus cannot be a URL.")
|
||||
@@ -155,6 +162,31 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
|
||||
return DataType.DOCX
|
||||
|
||||
if url.path.endswith(".yaml"):
|
||||
try:
|
||||
response = requests.get(source)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
yaml_content = yaml.safe_load(response.text)
|
||||
except yaml.YAMLError as exc:
|
||||
logging.error(f"Error parsing YAML: {exc}")
|
||||
raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
|
||||
|
||||
if is_openapi_yaml(yaml_content):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
|
||||
return DataType.OPENAPI
|
||||
else:
|
||||
logging.error(
|
||||
f"Source of `{formatted_source}` does not contain all the required \
|
||||
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
|
||||
)
|
||||
raise TypeError(
|
||||
"Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', \
|
||||
make sure you have all the required fields in YAML config data"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"Error fetching URL {formatted_source}: {e}")
|
||||
|
||||
if url.path.endswith(".json"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
|
||||
return DataType.JSON
|
||||
@@ -198,6 +230,22 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
|
||||
return DataType.XML
|
||||
|
||||
if source.endswith(".yaml"):
|
||||
with open(source, "r") as file:
|
||||
yaml_content = yaml.safe_load(file)
|
||||
if is_openapi_yaml(yaml_content):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
|
||||
return DataType.OPENAPI
|
||||
else:
|
||||
logging.error(
|
||||
f"Source of `{formatted_source}` does not contain all the required \
|
||||
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
|
||||
)
|
||||
raise ValueError(
|
||||
"Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', \
|
||||
make sure to add all the required params"
|
||||
)
|
||||
|
||||
if source.endswith(".json"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `json`.")
|
||||
return DataType.JSON
|
||||
|
||||
Reference in New Issue
Block a user