feat: csv loader (#470)

Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
Author: cachho
Date: 2023-09-05 10:18:03 +02:00
Committed by: GitHub
Parent: 344e7470f6
Commit: bd595f84e8

7 changed files with 172 additions and 0 deletions

embedchain/chunkers/table.py Normal file

@@ -0,0 +1,20 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig


class TableChunker(BaseChunker):
    """Chunker for tables, for instance csv, google sheets or databases."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
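
For context, TableChunker only configures langchain's RecursiveCharacterTextSplitter. The sketch below (not part of the commit) shows what the default settings (chunk_size=300, chunk_overlap=0) do to a row string in the "field: value" shape the CSV loader in this commit produces; the sample text is made up for illustration.

# Sketch, not part of the commit: default TableChunker splitting behaviour.
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Same defaults the chunker falls back to when no ChunkerConfig is passed.
splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0, length_function=len)

# Fabricated row text in the "field: value, field: value" shape the CSV loader emits.
row_text = ", ".join(f"col{i}: value number {i}" for i in range(40))

chunks = splitter.split_text(row_text)
print(len(chunks), [len(c) for c in chunks])  # every chunk stays at or under 300 characters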


@@ -3,11 +3,13 @@ from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config import AddConfig
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.loaders.csv import CsvLoader
from embedchain.loaders.docs_site_loader import DocsSiteLoader
from embedchain.loaders.docx_file import DocxFileLoader
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@@ -47,6 +49,7 @@ class DataFormatter(JSONSerializable):
            DataType.DOCX: DocxFileLoader,
            DataType.SITEMAP: SitemapLoader,
            DataType.DOCS_SITE: DocsSiteLoader,
            DataType.CSV: CsvLoader,
        }
        lazy_loaders = {DataType.NOTION}
        if data_type in loaders:
@@ -81,6 +84,7 @@ class DataFormatter(JSONSerializable):
            DataType.WEB_PAGE: WebPageChunker,
            DataType.DOCS_SITE: DocsSiteChunker,
            DataType.NOTION: NotionChunker,
            DataType.CSV: TableChunker,
        }
        if data_type in chunker_classes:
            chunker_class = chunker_classes[data_type]
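
With both registries updated, a CSV data type resolves to CsvLoader and TableChunker. A hedged sketch of that lookup follows; the DataFormatter constructor signature, its loader/chunker attributes, and the module paths are assumptions inferred from the surrounding code, not shown in this diff.

# Sketch, not part of the commit: resolving the loader and chunker for CSV.
# DataFormatter(data_type, config), .loader/.chunker, and the import paths are assumptions.
from embedchain.config import AddConfig
from embedchain.data_formatter.data_formatter import DataFormatter
from embedchain.models.data_type import DataType

formatter = DataFormatter(DataType.CSV, AddConfig())
print(type(formatter.loader).__name__)   # expected: CsvLoader
print(type(formatter.chunker).__name__)  # expected: TableChunker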

embedchain/loaders/csv.py Normal file

@@ -0,0 +1,46 @@
import csv
from io import StringIO
from urllib.parse import urlparse

import requests

from embedchain.loaders.base_loader import BaseLoader


class CsvLoader(BaseLoader):
    @staticmethod
    def _detect_delimiter(first_line):
        delimiters = [",", "\t", ";", "|"]
        counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
        return max(counts, key=counts.get)

    @staticmethod
    def _get_file_content(content):
        url = urlparse(content)
        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
            raise ValueError("Not a valid URL.")

        if url.scheme in ["http", "https"]:
            response = requests.get(content)
            response.raise_for_status()
            return StringIO(response.text)
        elif url.scheme == "file":
            path = url.path
            return open(path, newline="")  # Open the file using the path from the URI
        else:
            return open(content, newline="")  # Treat content as a regular file path

    @staticmethod
    def load_data(content):
        """Load a csv file with headers. Each line is a document"""
        result = []
        with CsvLoader._get_file_content(content) as file:
            first_line = file.readline()
            delimiter = CsvLoader._detect_delimiter(first_line)
            file.seek(0)  # Reset the file pointer to the start
            reader = csv.DictReader(file, delimiter=delimiter)
            for i, row in enumerate(reader):
                line = ", ".join([f"{field}: {value}" for field, value in row.items()])
                result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
        return result
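
A quick usage sketch for the loader above (not part of the commit), using a throwaway temporary file; the delimiter is detected from the header line, so the semicolons are picked up automatically.

# Sketch, not part of the commit: loading a small semicolon-delimited file.
import tempfile

from embedchain.loaders.csv import CsvLoader

with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as f:
    f.write("name;age\nalice;30\nbob;25\n")
    path = f.name

docs = CsvLoader.load_data(path)
print(docs[0]["content"])    # name: alice, age: 30
print(docs[0]["meta_data"])  # {'url': <path>, 'row': 1}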


@@ -11,3 +11,4 @@ class DataType(Enum):
    TEXT = "text"
    QNA_PAIR = "qna_pair"
    NOTION = "notion"
    CSV = "csv"


@@ -147,6 +147,10 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
return DataType.SITEMAP
if url.path.endswith(".csv"):
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV
if url.path.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
@@ -182,6 +186,10 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
if source.endswith(".csv"):
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV
# If the source is a valid file, that's not detectable as a type, an error is raised.
# It does not fallback to text.
raise ValueError(
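
A hedged sketch of exercising the new detection branches; the import paths below are assumptions (the module containing detect_datatype is not shown in this diff), and the comment in the second hunk suggests the local-path checks only apply to files that actually exist, so the sketch creates one first.

# Sketch, not part of the commit: the two new csv detection branches.
# Import paths for detect_datatype and DataType are assumed, not taken from this diff.
import pathlib

from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype

# URL branch: the .csv path suffix alone triggers the detection.
assert detect_datatype("https://example.com/exports/table.csv") == DataType.CSV

# Local-path branch: create a real file first, since the check runs on valid files.
pathlib.Path("table.csv").write_text("a,b\n1,2\n")
assert detect_datatype("table.csv") == DataType.CSV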