feat: add method - detect format / data_type (#380)

This commit is contained in:
cachho
2023-08-16 22:18:24 +02:00
committed by GitHub
parent f92e890aa1
commit 4c8876f032
18 changed files with 472 additions and 121 deletions

View File

@@ -4,6 +4,7 @@ import unittest
from embedchain.chunkers.text import TextChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType
class TestTextChunker(unittest.TestCase):
@@ -15,6 +16,8 @@ class TestTextChunker(unittest.TestCase):
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
chunker = TextChunker(config=chunker_config)
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
# Data type must be set manually in the test
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text)
@@ -31,6 +34,8 @@ class TestTextChunker(unittest.TestCase):
chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len)
chunker = TextChunker(config=chunker_config)
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
# Data type must be set manually in the test
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text)
@@ -46,6 +51,8 @@ class TestTextChunker(unittest.TestCase):
chunker = TextChunker(config=chunker_config)
# We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
# Data type must be set manually in the test
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text)

View File

@@ -23,5 +23,14 @@ class TestApp(unittest.TestCase):
The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the
'add' method.
"""
self.app.add("web_page", "https://example.com", {"meta": "meta-data"})
self.assertEqual(self.app.user_asks, [["web_page", "https://example.com", {"meta": "meta-data"}]])
self.app.add("https://example.com", metadata={"meta": "meta-data"})
self.assertEqual(self.app.user_asks, [["https://example.com", "web_page", {"meta": "meta-data"}]])
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_add_forced_type(self):
"""
Test that you can also force a data_type with `add`.
"""
data_type = "text"
self.app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
self.assertEqual(self.app.user_asks, [["https://example.com", data_type, {"meta": "meta-data"}]])

View File

@@ -31,7 +31,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
app.add_local("text", knowledge)
app.add(knowledge, data_type="text")
app.query("What text did I give you?")
app.chat("What text did I give you?")

View File

@@ -0,0 +1,129 @@
import tempfile
import unittest
from unittest.mock import patch
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype
class TestApp(unittest.TestCase):
"""Test that the datatype detection is working, based on the input."""
def test_detect_datatype_youtube(self):
self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(detect_datatype("https://m.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(
detect_datatype("https://www.youtube-nocookie.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO
)
self.assertEqual(detect_datatype("https://vid.plus/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(detect_datatype("https://youtu.be/dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
def test_detect_datatype_local_file(self):
self.assertEqual(detect_datatype("file:///home/user/file.txt"), DataType.WEB_PAGE)
def test_detect_datatype_pdf(self):
self.assertEqual(detect_datatype("https://www.example.com/document.pdf"), DataType.PDF_FILE)
def test_detect_datatype_local_pdf(self):
self.assertEqual(detect_datatype("file:///home/user/document.pdf"), DataType.PDF_FILE)
def test_detect_datatype_xml(self):
self.assertEqual(detect_datatype("https://www.example.com/sitemap.xml"), DataType.SITEMAP)
def test_detect_datatype_local_xml(self):
self.assertEqual(detect_datatype("file:///home/user/sitemap.xml"), DataType.SITEMAP)
def test_detect_datatype_docx(self):
self.assertEqual(detect_datatype("https://www.example.com/document.docx"), DataType.DOCX)
def test_detect_datatype_local_docx(self):
self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)
@patch("os.path.isfile")
def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
mock_isfile.return_value = True
self.assertEqual(detect_datatype(tmp.name), DataType.DOCX)
def test_detect_datatype_docs_site(self):
self.assertEqual(detect_datatype("https://docs.example.com"), DataType.DOCS_SITE)
def test_detect_datatype_docs_sitein_path(self):
self.assertEqual(detect_datatype("https://www.example.com/docs/index.html"), DataType.DOCS_SITE)
self.assertNotEqual(detect_datatype("file:///var/www/docs/index.html"), DataType.DOCS_SITE) # NOT equal
def test_detect_datatype_web_page(self):
self.assertEqual(detect_datatype("https://nav.al/agi"), DataType.WEB_PAGE)
def test_detect_datatype_invalid_url(self):
self.assertEqual(detect_datatype("not a url"), DataType.TEXT)
def test_detect_datatype_qna_pair(self):
self.assertEqual(
detect_datatype(("Question?", "Answer. Content of the string is irrelevant.")), DataType.QNA_PAIR
) #
def test_detect_datatype_qna_pair_types(self):
"""Test that a QnA pair needs to be a tuple of length two, and both items have to be strings."""
with self.assertRaises(TypeError):
self.assertNotEqual(
detect_datatype(("How many planets are in our solar system?", 8)), DataType.QNA_PAIR
) # NOT equal
def test_detect_datatype_text(self):
self.assertEqual(detect_datatype("Just some text."), DataType.TEXT)
def test_detect_datatype_non_string_error(self):
"""Test type error if the value passed is not a string, and not a valid non-string data_type"""
with self.assertRaises(TypeError):
detect_datatype(["foo", "bar"])
@patch("os.path.isfile")
def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
"""Test error if a valid file is referenced, but it isn't a valid data_type"""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
mock_isfile.return_value = True
with self.assertRaises(ValueError):
detect_datatype(tmp.name)
def test_detect_datatype_regular_filesystem_no_file(self):
"""Test that if a filepath is not actually an existing file, it is not handled as a file path."""
self.assertEqual(detect_datatype("/var/not-an-existing-file.txt"), DataType.TEXT)
def test_doc_examples_quickstart(self):
"""Test examples used in the documentation."""
self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Elon_Musk"), DataType.WEB_PAGE)
self.assertEqual(detect_datatype("https://www.tesla.com/elon-musk"), DataType.WEB_PAGE)
def test_doc_examples_introduction(self):
"""Test examples used in the documentation."""
self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=3qHkcs3kG44"), DataType.YOUTUBE_VIDEO)
self.assertEqual(
detect_datatype(
"https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf"
),
DataType.PDF_FILE,
)
self.assertEqual(detect_datatype("https://nav.al/feedback"), DataType.WEB_PAGE)
def test_doc_examples_app_types(self):
"""Test examples used in the documentation."""
self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=Ff4fRgnuFgQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Mark_Zuckerberg"), DataType.WEB_PAGE)
def test_doc_examples_configuration(self):
"""Test examples used in the documentation."""
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "wikipedia"])
import wikipedia
page = wikipedia.page("Albert Einstein")
# TODO: Add a wikipedia type, so wikipedia is a dependency and we don't need this slow test.
# (timings: import: 1.4s, fetch wiki: 0.7s)
self.assertEqual(detect_datatype(page.content), DataType.TEXT)
if __name__ == "__main__":
unittest.main()