feat: csv loader (#470)
Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
@@ -73,6 +73,15 @@ app.add('https://example.com/content/intro.docx', data_type="docx")
 app.add('content/intro.docx', data_type="docx")
 ```
 
+### CSV file
+
+To add any csv file, use the data_type as `csv`. `csv` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
+
+```python
+app.add('https://example.com/content/sheet.csv', data_type="csv")
+app.add('content/sheet.csv', data_type="csv")
+```
+
 ### Code documentation website loader
 
 To add any code documentation website as a loader, use the data_type as `docs_site`. Eg:
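Taken together with the docs snippet above, a minimal end-to-end sketch (assuming the standard embedchain `App` entry point; the sheet contents and query string are made up):

```python
from embedchain import App  # assumed standard entry point

app = App()
# A row like "Name,Age" / "Alice,28" is stored as "Name: Alice, Age: 28",
# so column names remain queryable alongside the values.
app.add('content/sheet.csv', data_type="csv")
print(app.query("How old is Alice?"))
```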
embedchain/chunkers/table.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
+
+
+class TableChunker(BaseChunker):
+    """Chunker for tables, for instance csv, google sheets or databases."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)
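A standalone sketch of what the default config above does, using langchain's splitter directly (the sample row is made up): a short `header: value` line fits well under the 300-character chunk size and passes through whole.

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Same defaults as TableChunker above: 300 chars per chunk, no overlap.
splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0, length_function=len)

row = "Name: Alice, Age: 28, Occupation: Engineer"  # hypothetical row document
print(splitter.split_text(row))
# -> ['Name: Alice, Age: 28, Occupation: Engineer']  (one chunk; under 300 chars)
```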
@@ -3,11 +3,13 @@ from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
+from embedchain.chunkers.table import TableChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.config import AddConfig
 from embedchain.helper_classes.json_serializable import JSONSerializable
+from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
 from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@@ -47,6 +49,7 @@ class DataFormatter(JSONSerializable):
             DataType.DOCX: DocxFileLoader,
             DataType.SITEMAP: SitemapLoader,
             DataType.DOCS_SITE: DocsSiteLoader,
+            DataType.CSV: CsvLoader,
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
@@ -81,6 +84,7 @@ class DataFormatter(JSONSerializable):
             DataType.WEB_PAGE: WebPageChunker,
             DataType.DOCS_SITE: DocsSiteChunker,
             DataType.NOTION: NotionChunker,
+            DataType.CSV: TableChunker,
         }
         if data_type in chunker_classes:
            chunker_class = chunker_classes[data_type]
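The two registrations above are plain dictionary dispatch; a self-contained sketch of the lookup (the `DataType` import path is an assumption, inferred from the enum hunk further down):

```python
from embedchain.chunkers.table import TableChunker
from embedchain.loaders.csv import CsvLoader
from embedchain.models.data_type import DataType  # assumed module path for the enum

loaders = {DataType.CSV: CsvLoader}
chunker_classes = {DataType.CSV: TableChunker}

data_type = DataType.CSV
loader = loaders[data_type]()            # CsvLoader instance
chunker = chunker_classes[data_type]()   # TableChunker with the 300-char default
```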
embedchain/loaders/csv.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import csv
+from io import StringIO
+from urllib.parse import urlparse
+
+import requests
+
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class CsvLoader(BaseLoader):
+    @staticmethod
+    def _detect_delimiter(first_line):
+        delimiters = [",", "\t", ";", "|"]
+        counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
+        return max(counts, key=counts.get)
+
+    @staticmethod
+    def _get_file_content(content):
+        url = urlparse(content)
+        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
+            raise ValueError("Not a valid URL.")
+
+        if url.scheme in ["http", "https"]:
+            response = requests.get(content)
+            response.raise_for_status()
+            return StringIO(response.text)
+        elif url.scheme == "file":
+            path = url.path
+            return open(path, newline="")  # Open the file using the path from the URI
+        else:
+            return open(content, newline="")  # Treat content as a regular file path
+
+    @staticmethod
+    def load_data(content):
+        """Load a csv file with headers. Each line is a document"""
+        result = []
+
+        with CsvLoader._get_file_content(content) as file:
+            first_line = file.readline()
+            delimiter = CsvLoader._detect_delimiter(first_line)
+            file.seek(0)  # Reset the file pointer to the start
+            reader = csv.DictReader(file, delimiter=delimiter)
+            for i, row in enumerate(reader):
+                line = ", ".join([f"{field}: {value}" for field, value in row.items()])
+                result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
+        return result
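A quick usage sketch for the loader on its own (`people.csv` and its contents are hypothetical): `load_data` is a staticmethod, so no instance is needed, and the delimiter is sniffed from the first line via `_detect_delimiter`.

```python
from embedchain.loaders.csv import CsvLoader

# Suppose people.csv contains:
#   Name;Age
#   Alice;28
docs = CsvLoader.load_data("people.csv")
print(docs[0]["content"])    # Name: Alice, Age: 28
print(docs[0]["meta_data"])  # {'url': 'people.csv', 'row': 1}
```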
@@ -11,3 +11,4 @@ class DataType(Enum):
     TEXT = "text"
     QNA_PAIR = "qna_pair"
     NOTION = "notion"
+    CSV = "csv"
@@ -147,6 +147,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
             return DataType.SITEMAP
 
+        if url.path.endswith(".csv"):
+            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
+            return DataType.CSV
+
         if url.path.endswith(".docx"):
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
@@ -182,6 +186,10 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
+        if source.endswith(".csv"):
+            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
+            return DataType.CSV
+
         # If the source is a valid file, that's not detectable as a type, an error is raised.
         # It does not fallback to text.
         raise ValueError(
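With the two branches above, `.csv` sources no longer need an explicit `data_type`; a sketch of the effect (the `detect_datatype` import path is an assumption):

```python
from embedchain.utils import detect_datatype  # assumed module path

detect_datatype("https://example.com/data.csv")  # -> DataType.CSV (URL branch)
detect_datatype("content/sheet.csv")  # -> DataType.CSV (file-path branch, for an existing file)
```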
tests/loaders/test_csv.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+import csv
+import os
+import pathlib
+import tempfile
+
+import pytest
+
+from embedchain.loaders.csv import CsvLoader
+
+
+@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
+def test_load_data(delimiter):
+    """
+    Test csv loader
+
+    Tests that file is loaded, metadata is correct and content is correct
+    """
+    # Creating temporary CSV file
+    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
+        writer = csv.writer(tmpfile, delimiter=delimiter)
+        writer.writerow(["Name", "Age", "Occupation"])
+        writer.writerow(["Alice", "28", "Engineer"])
+        writer.writerow(["Bob", "35", "Doctor"])
+        writer.writerow(["Charlie", "22", "Student"])
+
+        tmpfile.seek(0)
+        filename = tmpfile.name
+
+    # Loading CSV using CsvLoader
+    loader = CsvLoader()
+    result = loader.load_data(filename)
+
+    # Assertions
+    assert len(result) == 3
+    assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+    assert result[0]["meta_data"]["url"] == filename
+    assert result[0]["meta_data"]["row"] == 1
+    assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+    assert result[1]["meta_data"]["url"] == filename
+    assert result[1]["meta_data"]["row"] == 2
+    assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+    assert result[2]["meta_data"]["url"] == filename
+    assert result[2]["meta_data"]["row"] == 3
+
+    # Cleaning up the temporary file
+    os.unlink(filename)
+
+
+@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
+def test_load_data_with_file_uri(delimiter):
+    """
+    Test csv loader with file URI
+
+    Tests that file is loaded, metadata is correct and content is correct
+    """
+    # Creating temporary CSV file
+    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
+        writer = csv.writer(tmpfile, delimiter=delimiter)
+        writer.writerow(["Name", "Age", "Occupation"])
+        writer.writerow(["Alice", "28", "Engineer"])
+        writer.writerow(["Bob", "35", "Doctor"])
+        writer.writerow(["Charlie", "22", "Student"])
+
+        tmpfile.seek(0)
+        filename = pathlib.Path(tmpfile.name).as_uri()  # Convert path to file URI
+
+    # Loading CSV using CsvLoader
+    loader = CsvLoader()
+    result = loader.load_data(filename)
+
+    # Assertions
+    assert len(result) == 3
+    assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+    assert result[0]["meta_data"]["url"] == filename
+    assert result[0]["meta_data"]["row"] == 1
+    assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+    assert result[1]["meta_data"]["url"] == filename
+    assert result[1]["meta_data"]["row"] == 2
+    assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+    assert result[2]["meta_data"]["url"] == filename
+    assert result[2]["meta_data"]["row"] == 3
+
+    # Cleaning up the temporary file
+    os.unlink(tmpfile.name)
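Note that both tests are parametrized over the same four delimiters that `_detect_delimiter` recognizes, so the sniffing path is exercised end to end; they can be run in isolation with `pytest tests/loaders/test_csv.py`.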