class TableChunker(BaseChunker):
    """Chunker for tabular data sources, e.g. csv, google sheets or databases."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            # Table rows are short and self-contained, so use small chunks
            # with no overlap; a chunk then maps to a handful of rows.
            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(splitter)
class CsvLoader(BaseLoader):
    """Load delimited text files (comma, tab, semicolon or pipe separated)
    from a local path, a file:// URI or an http(s) URL."""

    @staticmethod
    def _detect_delimiter(first_line):
        """Guess the delimiter by counting candidate characters in the header line.

        :param first_line: The first (header) line of the file.
        :return: The candidate (",", "\\t", ";" or "|") with the highest count;
            "," when none of the candidates occur at all.
        """
        delimiters = [",", "\t", ";", "|"]
        counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
        return max(counts, key=counts.get)

    @staticmethod
    def _get_file_content(content):
        """Return a readable text stream for `content`.

        :param content: An http(s) URL, a file:// URI, or a plain file path.
        :return: A file-like object (StringIO for downloads, an open file otherwise).
        :raises ValueError: If `content` parses as a URL with an unsupported scheme.
        :raises requests.HTTPError: If an http(s) download returns an error status.
        """
        url = urlparse(content)
        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
            raise ValueError("Not a valid URL.")

        if url.scheme in ["http", "https"]:
            # Fix: a timeout is required so a stalled host cannot hang app.add() forever.
            response = requests.get(content, timeout=30)
            response.raise_for_status()
            return StringIO(response.text)
        elif url.scheme == "file":
            path = url.path
            # newline="" lets the csv module do its own universal-newline handling.
            return open(path, newline="")  # Open the file using the path from the URI
        else:
            return open(content, newline="")  # Treat content as a regular file path

    @staticmethod
    def load_data(content):
        """Load a csv file with headers. Each line is a document.

        :param content: URL, file URI or path of the delimited file.
        :return: A list of dicts, one per data row, of the form
            ``{"content": "col: value, ...", "meta_data": {"url": content, "row": n}}``
            where ``row`` is 1-based.
        """
        result = []

        with CsvLoader._get_file_content(content) as file:
            first_line = file.readline()
            delimiter = CsvLoader._detect_delimiter(first_line)
            file.seek(0)  # Reset the file pointer so DictReader re-reads the header row
            reader = csv.DictReader(file, delimiter=delimiter)
            for i, row in enumerate(reader):
                line = ", ".join([f"{field}: {value}" for field, value in row.items()])
                result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
        return result
# Shared fixture data for both tests: header + three data rows.
_SAMPLE_ROWS = [
    ("Alice", "28", "Engineer"),
    ("Bob", "35", "Doctor"),
    ("Charlie", "22", "Student"),
]


def _write_sample_csv(delimiter):
    """Write the sample table to a temporary file and return its path.

    The caller is responsible for deleting the file (delete=False is needed
    so the loader can reopen it by name, including on Windows).
    """
    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
        writer = csv.writer(tmpfile, delimiter=delimiter)
        writer.writerow(["Name", "Age", "Occupation"])
        writer.writerows(_SAMPLE_ROWS)
        return tmpfile.name


def _assert_sample_rows(result, url):
    """Assert content and metadata of every loaded document against the fixture."""
    assert len(result) == len(_SAMPLE_ROWS)
    for i, (name, age, occupation) in enumerate(_SAMPLE_ROWS):
        assert result[i]["content"] == f"Name: {name}, Age: {age}, Occupation: {occupation}"
        assert result[i]["meta_data"]["url"] == url
        assert result[i]["meta_data"]["row"] == i + 1


@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data(delimiter):
    """
    Test csv loader

    Tests that file is loaded, metadata is correct and content is correct
    """
    filename = _write_sample_csv(delimiter)
    # try/finally so the temp file is removed even when an assertion fails.
    try:
        loader = CsvLoader()
        result = loader.load_data(filename)
        _assert_sample_rows(result, filename)
    finally:
        os.unlink(filename)


@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data_with_file_uri(delimiter):
    """
    Test csv loader with file URI

    Tests that file is loaded, metadata is correct and content is correct
    """
    path = _write_sample_csv(delimiter)
    filename = pathlib.Path(path).as_uri()  # Convert path to file URI
    try:
        loader = CsvLoader()
        result = loader.load_data(filename)
        _assert_sample_rows(result, filename)
    finally:
        os.unlink(path)