Add dry_run to add() (#545)
This commit is contained in:
@@ -3,7 +3,8 @@ import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config import AppConfig, AddConfig, ChunkerConfig
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
@@ -34,3 +35,28 @@ class TestApp(unittest.TestCase):
|
||||
data_type = "text"
|
||||
self.app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
|
||||
self.assertEqual(self.app.user_asks, [["https://example.com", data_type, {"meta": "meta-data"}]])
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_dry_run(self):
|
||||
"""
|
||||
Test that if dry_run == True then data chunks are returned.
|
||||
"""
|
||||
|
||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
|
||||
# We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
|
||||
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
|
||||
|
||||
result = self.app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
|
||||
|
||||
chunks = result["chunks"]
|
||||
metadata = result["metadata"]
|
||||
count = result["count"]
|
||||
data_type = result["type"]
|
||||
|
||||
self.assertEqual(len(chunks), len(text))
|
||||
self.assertEqual(count, len(text))
|
||||
self.assertEqual(data_type, DataType.TEXT)
|
||||
for item in metadata:
|
||||
self.assertIsInstance(item, dict)
|
||||
self.assertIn(item["url"], "local")
|
||||
self.assertIn(item["data_type"], "text")
|
||||
|
||||
Reference in New Issue
Block a user