feat: add local qna pair

This commit is contained in:
cachho
2023-06-23 11:43:33 +02:00
committed by Taranjeet Singh
parent df7614d349
commit ff2d5ce7fa
4 changed files with 52 additions and 2 deletions

View File

@@ -0,0 +1,16 @@
from embedchain.chunkers.base_chunker import BaseChunker
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Keyword arguments unpacked into RecursiveCharacterTextSplitter when a
# QnaPairChunker is created: a chunk size of 300 as measured by `len`,
# and zero overlap between chunks.
TEXT_SPLITTER_CHUNK_PARAMS = dict(
    chunk_size=300,
    chunk_overlap=0,
    length_function=len,
)
class QnaPairChunker(BaseChunker):
    """Chunker used for the `qna_pair` data type.

    Configures a RecursiveCharacterTextSplitter from
    TEXT_SPLITTER_CHUNK_PARAMS and hands it to BaseChunker.
    """

    def __init__(self):
        """Build the text splitter and pass it to the base chunker."""
        super().__init__(
            RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
        )

View File

@@ -8,9 +8,11 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from embedchain.loaders.youtube_video import YoutubeVideoLoader from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.loaders.pdf_file import PdfFileLoader from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.web_page import WebPageLoader from embedchain.loaders.web_page import WebPageLoader
from embedchain.loaders_local.qna_pair import QnaPairLoader
from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.web_page import WebPageChunker from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.vectordb.chroma_db import ChromaDB from embedchain.vectordb.chroma_db import ChromaDB
load_dotenv() load_dotenv()
@@ -46,7 +48,8 @@ class EmbedChain:
loaders = { loaders = {
'youtube_video': YoutubeVideoLoader(), 'youtube_video': YoutubeVideoLoader(),
'pdf_file': PdfFileLoader(), 'pdf_file': PdfFileLoader(),
'web_page': WebPageLoader() 'web_page': WebPageLoader(),
'qna_pair': QnaPairLoader()
} }
if data_type in loaders: if data_type in loaders:
return loaders[data_type] return loaders[data_type]
@@ -64,7 +67,8 @@ class EmbedChain:
chunkers = { chunkers = {
'youtube_video': YoutubeVideoChunker(), 'youtube_video': YoutubeVideoChunker(),
'pdf_file': PdfFileChunker(), 'pdf_file': PdfFileChunker(),
'web_page': WebPageChunker() 'web_page': WebPageChunker(),
'qna_pair': QnaPairChunker(),
} }
if data_type in chunkers: if data_type in chunkers:
return chunkers[data_type] return chunkers[data_type]
@@ -85,6 +89,20 @@ class EmbedChain:
self.user_asks.append([data_type, url]) self.user_asks.append([data_type, url])
self.load_and_embed(loader, chunker, url) self.load_and_embed(loader, chunker, url)
def add_local(self, data_type, content):
    """
    Add data that you supply directly to the vector database.

    Looks up the loader and the chunker registered for `data_type`,
    records the request in `self.user_asks`, and then delegates to
    `load_and_embed`, which loads the data, chunks it, and adds it
    to the database.

    :param data_type: The type of the data to add.
    :param content: The local data. Refer to the `README` for formatting.
    """
    # Resolve both helpers first so the request is only recorded once
    # the lookups for this data type have gone through.
    data_loader = self._get_loader(data_type)
    data_chunker = self._get_chunker(data_type)
    self.user_asks.append([data_type, content])
    self.load_and_embed(data_loader, data_chunker, content)
def load_and_embed(self, loader, chunker, url): def load_and_embed(self, loader, chunker, url):
""" """
Loads the data from the given URL, chunks it, and adds it to the database. Loads the data from the given URL, chunks it, and adds it to the database.

View File

View File

@@ -0,0 +1,16 @@
from embedchain.utils import markdown_to_plaintext
class QnaPairLoader:
    """Loader that turns a locally supplied question/answer pair into a document."""

    def load_data(self, content):
        """
        Build a single document from a question/answer pair.

        :param content: A two-item iterable unpacked as (question, answer).
        :return: A one-element list holding a dict with the formatted text
            under "content" and the metadata under "meta_data".
        """
        question, raw_answer = content
        # Only the answer is run through markdown_to_plaintext; the
        # question is inserted as given.
        plain_answer = markdown_to_plaintext(raw_answer)
        document = {
            "content": f"Q: {question}\nA: {plain_answer}",
            # Local data has no source URL, so a fixed marker is stored.
            "meta_data": {"url": "local"},
        }
        return [document]