refactor: loader chunker typing (#324)

This commit is contained in:
cachho
2023-07-26 19:44:57 +02:00
committed by GitHub
parent a8552686b4
commit 55bfd7cafe
10 changed files with 34 additions and 9 deletions

View File

@@ -6,10 +6,12 @@ from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, ChatConfig, QueryConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
from embedchain.data_formatter import DataFormatter
from embedchain.loaders.base_loader import BaseLoader
load_dotenv()
@@ -80,7 +82,7 @@ class EmbedChain:
metadata,
)
def load_and_embed(self, loader, chunker, src, metadata=None):
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
"""
Loads the data from the given URL, chunks it, and adds it to database.

View File

@@ -0,0 +1,9 @@
class BaseLoader:
def __init__(self):
pass
def load_data():
"""
Implemented by child classes
"""
pass

View File

@@ -4,8 +4,10 @@ from urllib.parse import urljoin, urlparse
import requests
from bs4 import BeautifulSoup
from embedchain.loaders.base_loader import BaseLoader
class DocsSiteLoader:
class DocsSiteLoader(BaseLoader):
def __init__(self):
self.visited_links = set()

View File

@@ -1,7 +1,9 @@
from langchain.document_loaders import Docx2txtLoader
from embedchain.loaders.base_loader import BaseLoader
class DocxFileLoader:
class DocxFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a .docx file."""
loader = Docx2txtLoader(url)

View File

@@ -1,4 +1,7 @@
class LocalQnaPairLoader:
from embedchain.loaders.base_loader import BaseLoader
class LocalQnaPairLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local QnA pair."""
question, answer = content

View File

@@ -1,4 +1,7 @@
class LocalTextLoader:
from embedchain.loaders.base_loader import BaseLoader
class LocalTextLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local text file."""
meta_data = {

View File

@@ -1,9 +1,10 @@
from langchain.document_loaders import PyPDFLoader
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
class PdfFileLoader:
class PdfFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a PDF file."""
loader = PyPDFLoader(url)

View File

@@ -4,11 +4,12 @@ import requests
from bs4 import BeautifulSoup
from bs4.builder import ParserRejectedMarkup
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.utils import is_readable
class SitemapLoader:
class SitemapLoader(BaseLoader):
def load_data(self, sitemap_url):
"""
This method takes a sitemap URL as input and retrieves

View File

@@ -3,10 +3,11 @@ import logging
import requests
from bs4 import BeautifulSoup
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
class WebPageLoader:
class WebPageLoader(BaseLoader):
def load_data(self, url):
"""Load data from a web page."""
response = requests.get(url)

View File

@@ -1,9 +1,10 @@
from langchain.document_loaders import YoutubeLoader
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
class YoutubeVideoLoader:
class YoutubeVideoLoader(BaseLoader):
def load_data(self, url):
"""Load data from a Youtube video."""
loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)