diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 7b907c36..a552aa89 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -6,10 +6,12 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 
+from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
+from embedchain.loaders.base_loader import BaseLoader
 
 load_dotenv()
 
@@ -80,7 +82,7 @@ class EmbedChain:
             metadata,
         )
 
-    def load_and_embed(self, loader, chunker, src, metadata=None):
+    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
 
diff --git a/embedchain/loaders/base_loader.py b/embedchain/loaders/base_loader.py
new file mode 100644
index 00000000..83048518
--- /dev/null
+++ b/embedchain/loaders/base_loader.py
@@ -0,0 +1,14 @@
+class BaseLoader:
+    """Base class that the loaders in embedchain/loaders inherit from."""
+
+    def __init__(self):
+        pass
+
+    def load_data(self, src):
+        """
+        Load data from the given source. Implemented by child classes.
+
+        :param src: The source to load, e.g. a URL or local content.
+        :raises NotImplementedError: If a child class does not override this method.
+        """
+        raise NotImplementedError("load_data() must be implemented by child classes")
diff --git a/embedchain/loaders/docs_site_loader.py b/embedchain/loaders/docs_site_loader.py
index 404ba8b1..63d6bba0 100644
--- a/embedchain/loaders/docs_site_loader.py
+++ b/embedchain/loaders/docs_site_loader.py
@@ -4,8 +4,10 @@ from urllib.parse import urljoin, urlparse
 import requests
 from bs4 import BeautifulSoup
 
+from embedchain.loaders.base_loader import BaseLoader
 
-class DocsSiteLoader:
+
+class DocsSiteLoader(BaseLoader):
     def __init__(self):
         self.visited_links = set()
 
diff --git a/embedchain/loaders/docx_file.py b/embedchain/loaders/docx_file.py
index d88198be..9a304939 100644
--- a/embedchain/loaders/docx_file.py
+++ b/embedchain/loaders/docx_file.py
@@ -1,7 +1,9 @@
 from langchain.document_loaders import Docx2txtLoader
 
+from embedchain.loaders.base_loader import BaseLoader
 
-class DocxFileLoader:
+
+class DocxFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a .docx file."""
         loader = Docx2txtLoader(url)
diff --git a/embedchain/loaders/local_qna_pair.py b/embedchain/loaders/local_qna_pair.py
index ceaf9779..673c009e 100644
--- a/embedchain/loaders/local_qna_pair.py
+++ b/embedchain/loaders/local_qna_pair.py
@@ -1,4 +1,7 @@
-class LocalQnaPairLoader:
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class LocalQnaPairLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local QnA pair."""
         question, answer = content
diff --git a/embedchain/loaders/local_text.py b/embedchain/loaders/local_text.py
index c61a6d85..779b2036 100644
--- a/embedchain/loaders/local_text.py
+++ b/embedchain/loaders/local_text.py
@@ -1,4 +1,7 @@
-class LocalTextLoader:
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class LocalTextLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local text file."""
         meta_data = {
diff --git a/embedchain/loaders/pdf_file.py b/embedchain/loaders/pdf_file.py
index 888193ba..06e88ca9 100644
--- a/embedchain/loaders/pdf_file.py
+++ b/embedchain/loaders/pdf_file.py
@@ -1,9 +1,10 @@
 from langchain.document_loaders import PyPDFLoader
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
-class PdfFileLoader:
+class PdfFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a PDF file."""
         loader = PyPDFLoader(url)
diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py
index 09dd5f53..4442dec5 100644
--- a/embedchain/loaders/sitemap.py
+++ b/embedchain/loaders/sitemap.py
@@ -4,11 +4,12 @@ import requests
 from bs4 import BeautifulSoup
 from bs4.builder import ParserRejectedMarkup
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 from embedchain.utils import is_readable
 
 
-class SitemapLoader:
+class SitemapLoader(BaseLoader):
     def load_data(self, sitemap_url):
         """
         This method takes a sitemap URL as input and retrieves
diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py
index cb3ba181..417898ea 100644
--- a/embedchain/loaders/web_page.py
+++ b/embedchain/loaders/web_page.py
@@ -3,10 +3,11 @@ import logging
 import requests
 from bs4 import BeautifulSoup
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
-class WebPageLoader:
+class WebPageLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a web page."""
         response = requests.get(url)
diff --git a/embedchain/loaders/youtube_video.py b/embedchain/loaders/youtube_video.py
index 59790b57..5cc6cc0d 100644
--- a/embedchain/loaders/youtube_video.py
+++ b/embedchain/loaders/youtube_video.py
@@ -1,9 +1,10 @@
 from langchain.document_loaders import YoutubeLoader
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 
 
-class YoutubeVideoLoader:
+class YoutubeVideoLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a Youtube video."""
         loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)