[feat]: Add openapi spec data loader (#818)

This commit is contained in:
Deven Patel
2023-10-25 14:19:13 -07:00
committed by GitHub
parent f2a5dc40ee
commit 797bb567c6
13 changed files with 212 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
class OpenAPIChunker(BaseChunker):
    """Chunker registered for OpenAPI spec data, built on a recursive character splitter."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """Create the text splitter and hand it to the base chunker.

        :param config: chunking settings; when omitted, a config with
            ``chunk_size=1000``, ``chunk_overlap=0`` and ``length_function=len``
            is used.
        """
        settings = (
            config
            if config is not None
            else ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        )
        splitter_options = {
            "chunk_size": settings.chunk_size,
            "chunk_overlap": settings.chunk_overlap,
            "length_function": settings.length_function,
        }
        super().__init__(RecursiveCharacterTextSplitter(**splitter_options))

View File

@@ -5,6 +5,7 @@ from embedchain.chunkers.images import ImagesChunker
from embedchain.chunkers.json import JSONChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.openapi import OpenAPIChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker
@@ -26,6 +27,7 @@ from embedchain.loaders.json import JSONLoader
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
from embedchain.loaders.local_text import LocalTextLoader
from embedchain.loaders.mdx import MdxLoader
from embedchain.loaders.openapi import OpenAPILoader
from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.sitemap import SitemapLoader
from embedchain.loaders.unstructured_file import UnstructuredLoader
@@ -81,6 +83,7 @@ class DataFormatter(JSONSerializable):
DataType.IMAGES: ImagesLoader,
DataType.UNSTRUCTURED: UnstructuredLoader,
DataType.JSON: JSONLoader,
DataType.OPENAPI: OpenAPILoader,
}
lazy_loaders = {DataType.NOTION}
if data_type in loaders:
@@ -124,6 +127,7 @@ class DataFormatter(JSONSerializable):
DataType.XML: XmlChunker,
DataType.UNSTRUCTURED: UnstructuredFileChunker,
DataType.JSON: JSONChunker,
DataType.OPENAPI: OpenAPIChunker,
}
if data_type in chunker_classes:
chunker_class: type = chunker_classes[data_type]

View File

@@ -0,0 +1,42 @@
import hashlib
from io import StringIO
from urllib.parse import urlparse
import requests
import yaml
from embedchain.loaders.base_loader import BaseLoader
class OpenAPILoader(BaseLoader):
    """Loader that turns an OpenAPI YAML spec into one document per top-level key."""

    @staticmethod
    def _get_file_content(content):
        """Open the spec referenced by ``content`` as a readable text stream.

        :param content: an ``http(s)://`` URL, a ``file://`` URL, or a local file path.
        :return: a file-like object that can be used as a context manager.
        :raises ValueError: if ``content`` is a URL with an unsupported scheme.
        :raises requests.exceptions.RequestException: if the HTTP request fails,
            times out, or returns an error status.
        :raises OSError: if the local file cannot be opened.
        """
        url = urlparse(content)
        # Only reject strings that really look like URLs (scheme AND netloc), so
        # that plain local paths still fall through to ``open`` below.
        if url.scheme and url.netloc and url.scheme not in ("file", "http", "https"):
            raise ValueError("Not a valid URL.")
        if url.scheme in ("http", "https"):
            # A timeout keeps an unresponsive server from blocking the caller forever.
            response = requests.get(content, timeout=30)
            response.raise_for_status()
            return StringIO(response.text)
        # ``file://`` URLs carry the filesystem path in their path component;
        # anything else is treated as a plain local path.
        path = url.path if url.scheme == "file" else content
        # Decode as UTF-8 explicitly rather than depending on the platform locale.
        return open(path, encoding="utf-8")

    @staticmethod
    def load_data(content):
        """Load an OpenAPI YAML spec; each top-level key/value pair is a document.

        :param content: URL or local path of the YAML spec.
        :return: ``{"doc_id": ..., "data": [...]}`` where each entry of ``data`` is
            ``{"content": "<key>: <value>", "meta_data": {"url": content, "row": n}}``
            with ``n`` starting at 1, and ``doc_id`` is the SHA-256 hex digest of
            ``content`` followed by the document strings joined with ``", "``.
        :raises ValueError: if the YAML document is not a mapping (for example an
            empty file or a top-level list).
        """
        with OpenAPILoader._get_file_content(content=content) as file:
            # NOTE(security): this previously used ``yaml.load(..., Loader=yaml.Loader)``,
            # which can construct arbitrary Python objects. The spec may come from a
            # remote URL, so it is parsed with the safe loader instead.
            yaml_data = yaml.safe_load(file)

        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Invalid OpenAPI YAML: expected a mapping at the top level, got {type(yaml_data).__name__}."
            )

        data = []
        data_content = []
        for row, (key, value) in enumerate(yaml_data.items(), start=1):
            string_data = f"{key}: {value}"
            data.append({"content": string_data, "meta_data": {"url": content, "row": row}})
            data_content.append(string_data)

        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": data}

View File

@@ -27,6 +27,7 @@ class IndirectDataType(Enum):
IMAGES = "images"
UNSTRUCTURED = "unstructured"
JSON = "json"
OPENAPI = "openapi"
class SpecialDataType(Enum):
@@ -53,3 +54,4 @@ class DataType(Enum):
IMAGES = IndirectDataType.IMAGES.value
UNSTRUCTURED = IndirectDataType.UNSTRUCTURED.value
JSON = IndirectDataType.JSON.value
OPENAPI = IndirectDataType.OPENAPI.value

View File

@@ -115,6 +115,13 @@ def detect_datatype(source: Any) -> DataType:
"""
from urllib.parse import urlparse
import requests
import yaml
def is_openapi_yaml(yaml_content):
    """Return True when parsed YAML content has the top-level OpenAPI fields.

    :param yaml_content: the object produced by parsing a YAML document; may be
        ``None`` (empty document), a scalar, a list, or a mapping.
    :return: ``True`` only if ``yaml_content`` is a mapping containing both the
        ``openapi`` and ``info`` keys.
    """
    # Require a mapping first: ``in`` on None raises TypeError, and on a string
    # or list it would do substring/element matching and give false positives.
    # ``openapi`` and ``info`` are the two fields this check treats as required.
    return isinstance(yaml_content, dict) and "openapi" in yaml_content and "info" in yaml_content
try:
if not isinstance(source, str):
raise ValueError("Source is not a string and thus cannot be a URL.")
@@ -155,6 +162,31 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
if url.path.endswith(".yaml"):
try:
response = requests.get(source)
response.raise_for_status()
try:
yaml_content = yaml.safe_load(response.text)
except yaml.YAMLError as exc:
logging.error(f"Error parsing YAML: {exc}")
raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
if is_openapi_yaml(yaml_content):
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
return DataType.OPENAPI
else:
logging.error(
f"Source of `{formatted_source}` does not contain all the required \
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
)
raise TypeError(
"Not a valid data type. Check 'https://spec.openapis.org/oas/v3.1.0', \
make sure you have all the required fields in YAML config data"
)
except requests.exceptions.RequestException as e:
logging.error(f"Error fetching URL {formatted_source}: {e}")
if url.path.endswith(".json"):
logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
return DataType.JSON
@@ -198,6 +230,22 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
return DataType.XML
if source.endswith(".yaml"):
with open(source, "r") as file:
yaml_content = yaml.safe_load(file)
if is_openapi_yaml(yaml_content):
logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
return DataType.OPENAPI
else:
logging.error(
f"Source of `{formatted_source}` does not contain all the required \
fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
)
raise ValueError(
"Invalid YAML data. Check 'https://spec.openapis.org/oas/v3.1.0', \
make sure to add all the required params"
)
if source.endswith(".json"):
logging.debug(f"Source of `{formatted_source}` detected as `json`.")
return DataType.JSON