refactor: loader chunker typing (#324)
This commit is contained in:
@@ -6,10 +6,12 @@ from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, ChatConfig, QueryConfig
|
||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -80,7 +82,7 @@ class EmbedChain:
|
||||
metadata,
|
||||
)
|
||||
|
||||
def load_and_embed(self, loader, chunker, src, metadata=None):
|
||||
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
|
||||
"""
|
||||
Loads the data from the given URL, chunks it, and adds it to database.
|
||||
|
||||
|
||||
9
embedchain/loaders/base_loader.py
Normal file
9
embedchain/loaders/base_loader.py
Normal file
@@ -0,0 +1,9 @@
|
||||
class BaseLoader:
    """Common base class for the data loaders.

    Concrete loaders subclass this and override ``load_data``; the base
    implementation does nothing and returns ``None``.
    """

    def __init__(self):
        pass

    def load_data(self, src=None):
        """
        Implemented by child classes

        :param src: The source to load from. The subclass overrides visible
            alongside this class each take a single positional source
            argument (named ``url``, ``content`` or ``sitemap_url``), so the
            base accepts one as well; it is ignored here.
        :return: ``None``. Subclasses return the loaded data.
        """
        # The method was previously declared without ``self``, which made
        # any call on an instance fail with a TypeError before the body ran.
        pass
|
||||
@@ -4,8 +4,10 @@ from urllib.parse import urljoin, urlparse
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
class DocsSiteLoader:
|
||||
|
||||
class DocsSiteLoader(BaseLoader):
|
||||
def __init__(self):
|
||||
self.visited_links = set()
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from langchain.document_loaders import Docx2txtLoader
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
class DocxFileLoader:
|
||||
|
||||
class DocxFileLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a .docx file."""
|
||||
loader = Docx2txtLoader(url)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
class LocalQnaPairLoader:
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
class LocalQnaPairLoader(BaseLoader):
|
||||
def load_data(self, content):
|
||||
"""Load data from a local QnA pair."""
|
||||
question, answer = content
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
class LocalTextLoader:
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
class LocalTextLoader(BaseLoader):
|
||||
def load_data(self, content):
|
||||
"""Load data from a local text file."""
|
||||
meta_data = {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
class PdfFileLoader:
|
||||
class PdfFileLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a PDF file."""
|
||||
loader = PyPDFLoader(url)
|
||||
|
||||
@@ -4,11 +4,12 @@ import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4.builder import ParserRejectedMarkup
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
from embedchain.utils import is_readable
|
||||
|
||||
|
||||
class SitemapLoader:
|
||||
class SitemapLoader(BaseLoader):
|
||||
def load_data(self, sitemap_url):
|
||||
"""
|
||||
This method takes a sitemap URL as input and retrieves
|
||||
|
||||
@@ -3,10 +3,11 @@ import logging
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
class WebPageLoader:
|
||||
class WebPageLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a web page."""
|
||||
response = requests.get(url)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from langchain.document_loaders import YoutubeLoader
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
class YoutubeVideoLoader:
|
||||
class YoutubeVideoLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a Youtube video."""
|
||||
loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)
|
||||
|
||||
Reference in New Issue
Block a user