feat: add local qna pair

This commit is contained in:
cachho
2023-06-23 11:43:33 +02:00
committed by Taranjeet Singh
parent df7614d349
commit ff2d5ce7fa
4 changed files with 52 additions and 2 deletions

View File

@@ -0,0 +1,16 @@
from embedchain.chunkers.base_chunker import BaseChunker
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Keyword arguments unpacked into RecursiveCharacterTextSplitter when the
# Q&A-pair chunker is constructed: 300-unit chunks, no overlap between
# neighbouring chunks, size measured with the builtin ``len``.
TEXT_SPLITTER_CHUNK_PARAMS = dict(
    chunk_size=300,
    chunk_overlap=0,
    length_function=len,
)
class QnaPairChunker(BaseChunker):
    """Chunker used for question-and-answer pair content.

    Wraps a ``RecursiveCharacterTextSplitter`` configured from the
    module-level ``TEXT_SPLITTER_CHUNK_PARAMS`` and hands it to
    ``BaseChunker``.
    """

    def __init__(self):
        """Create the text splitter and pass it to the base chunker."""
        super().__init__(
            RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
        )

View File

@@ -8,9 +8,11 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.loaders_local.qna_pair import QnaPairLoader
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.vectordb.chroma_db import ChromaDB
load_dotenv()
@@ -46,7 +48,8 @@ class EmbedChain:
loaders = {
'youtube_video': YoutubeVideoLoader(),
'pdf_file': PdfFileLoader(),
'web_page': WebPageLoader()
'web_page': WebPageLoader(),
'qna_pair': QnaPairLoader()
}
if data_type in loaders:
return loaders[data_type]
@@ -64,7 +67,8 @@ class EmbedChain:
chunkers = {
'youtube_video': YoutubeVideoChunker(),
'pdf_file': PdfFileChunker(),
'web_page': WebPageChunker()
'web_page': WebPageChunker(),
'qna_pair': QnaPairChunker(),
}
if data_type in chunkers:
return chunkers[data_type]
@@ -85,6 +89,20 @@ class EmbedChain:
self.user_asks.append([data_type, url])
self.load_and_embed(loader, chunker, url)
def add_local(self, data_type, content):
    """
    Add data supplied directly by the caller to the vector db.

    Resolves the loader and the chunker for `data_type`, records the
    request in `self.user_asks`, and then delegates to `load_and_embed`,
    passing `content` in the position that method uses for its source.

    :param data_type: The type of the data to add.
    :param content: The local data. Refer to the `README` for formatting.
    """
    # Resolve both helpers before recording anything, so a failed lookup
    # leaves `user_asks` untouched.
    selected_loader = self._get_loader(data_type)
    selected_chunker = self._get_chunker(data_type)

    request_record = [data_type, content]
    self.user_asks.append(request_record)

    self.load_and_embed(selected_loader, selected_chunker, content)
def load_and_embed(self, loader, chunker, url):
"""
Loads the data from the given URL, chunks it, and adds it to the database.

View File

View File

@@ -0,0 +1,16 @@
from embedchain.utils import markdown_to_plaintext
class QnaPairLoader:
    """Loader for a question/answer pair passed in directly by the caller."""

    def load_data(self, content):
        """
        Turn a (question, answer) pair into a single loadable document.

        The answer is passed through `markdown_to_plaintext` before it is
        combined with the question into one "Q: ... / A: ..." text.

        :param content: An iterable that unpacks into exactly two items,
            the question followed by the answer.
        :return: A one-element list holding a dict with the combined text
            under "content" and a "meta_data" dict whose "url" is "local".
        """
        question, raw_answer = content
        plain_answer = markdown_to_plaintext(raw_answer)
        document = {
            "content": f"Q: {question}\nA: {plain_answer}",
            # No remote source exists for local data, so the url is a fixed marker.
            "meta_data": {"url": "local"},
        }
        return [document]