Add support for image dataset (#571)

Author: Rupesh Bansal
Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
Date: 2023-10-04 09:50:40 +05:30
Committed by: GitHub
Parent: 55e9a1cbd6
Commit: d0af018b8d
19 changed files with 498 additions and 31 deletions

View File

@@ -0,0 +1,72 @@
import unittest

from embedchain.chunkers.images import ImagesChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType


class TestImageChunker(unittest.TestCase):
    def test_chunks(self):
        """
        Test the chunks generated by ImagesChunker.
        # TODO: Not a very precise test.
        """
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
        chunker = ImagesChunker(config=chunker_config)
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.IMAGES)

        image_path = "./tmp/image.jpeg"
        result = chunker.create_chunks(MockLoader(), image_path)

        expected_chunks = {
            "doc_id": "123",
            "documents": [image_path],
            "embeddings": ["embedding"],
            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
        }
        self.assertEqual(expected_chunks, result)

    def test_chunks_with_default_config(self):
        """
        Test the chunks generated by ImagesChunker with the default config.
        """
        chunker = ImagesChunker()
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.IMAGES)

        image_path = "./tmp/image.jpeg"
        result = chunker.create_chunks(MockLoader(), image_path)

        expected_chunks = {
            "doc_id": "123",
            "documents": [image_path],
            "embeddings": ["embedding"],
            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
        }
        self.assertEqual(expected_chunks, result)

    def test_word_count(self):
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
        chunker = ImagesChunker(config=chunker_config)
        chunker.set_data_type(DataType.IMAGES)

        # For images, the word count is reported as a single unit regardless of document structure.
        document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
        result = chunker.get_word_count(document)
        self.assertEqual(result, 1)


class MockLoader:
    def load_data(self, src):
        """
        Mock loader that returns a list of data dictionaries.
        Adjust this method to return different data for testing.
        """
        return {
            "doc_id": "123",
            "data": [
                {
                    "content": src,
                    "embedding": "embedding",
                    "meta_data": {"url": "none"},
                }
            ],
        }
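Note (not part of the diff): the MockLoader above exercises the chunker contract — create_chunks maps the loader's output into parallel documents/embeddings/ids/metadatas lists. A minimal sketch of that mapping follows, assuming the ids are a SHA-256 digest of the chunk content; the actual ImagesChunker may derive them differently.

import hashlib
from typing import Any, Dict


def create_chunks_sketch(loader: Any, src: str, data_type: str = "images") -> Dict:
    """Illustrative only: build the dict shape asserted in the tests above."""
    loaded = loader.load_data(src)
    documents, embeddings, ids, metadatas = [], [], [], []
    for record in loaded["data"]:
        content = record["content"]
        documents.append(content)
        embeddings.append(record["embedding"])
        # Assumption: ids are a SHA-256 digest of the chunk content; the real
        # ImagesChunker may hash different inputs.
        ids.append(hashlib.sha256(content.encode("utf-8")).hexdigest())
        meta = dict(record["meta_data"])
        meta.update({"data_type": data_type, "doc_id": loaded["doc_id"]})
        metadatas.append(meta)
    return {
        "doc_id": loaded["doc_id"],
        "documents": documents,
        "embeddings": embeddings,
        "ids": ids,
        "metadatas": metadatas,
    }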

View File

@@ -62,6 +62,15 @@ class TestTextChunker(unittest.TestCase):
        self.assertEqual(len(documents), len(text))

    def test_word_count(self):
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
        chunker = TextChunker(config=chunker_config)
        chunker.set_data_type(DataType.TEXT)

        document = ["ab cd", "ef gh"]
        result = chunker.get_word_count(document)
        self.assertEqual(result, 4)


class MockLoader:
    def load_data(self, src):
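Note (not part of the diff): the two test_word_count cases pin down different behaviours — TextChunker counts whitespace-separated words across documents (4 above), while ImagesChunker reports a whole image as one unit (1). A rough sketch of overrides that would satisfy both assertions; these are illustrative, not the actual embedchain implementations.

from typing import Any, List


class TextChunkerSketch:
    def get_word_count(self, documents: List[str]) -> int:
        # Sum of whitespace-separated tokens across all documents: 4 for ["ab cd", "ef gh"].
        return sum(len(doc.split()) for doc in documents)


class ImagesChunkerSketch:
    def get_word_count(self, documents: Any) -> int:
        # An image has no meaningful word count; report the whole input as one unit.
        return 1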

BIN tests/models/image.jpg (new binary file, 27 KiB; contents not shown)

View File

@@ -0,0 +1,55 @@
import os
import tempfile
import unittest
import urllib.request

from PIL import Image

from embedchain.models.clip_processor import ClipProcessor


class ClipProcessorTest(unittest.TestCase):
    def test_load_model(self):
        # Test that `load_model()` returns both the CLIP model and the image preprocessing transform.
        model, preprocess = ClipProcessor.load_model()

        # Assert that the model is not None.
        self.assertIsNotNone(model)

        # Assert that the preprocess transform is not None.
        self.assertIsNotNone(preprocess)

    def test_get_image_features(self):
        # Download a sample image into a temporary folder.
        with tempfile.TemporaryDirectory() as tmp_dir:
            downloaded_path = os.path.join(tmp_dir, "downloaded.jpg")
            image_path = os.path.join(tmp_dir, "image.jpg")
            urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", downloaded_path)

            # Re-save through PIL, mirroring the original open-and-save flow.
            Image.open(downloaded_path).save(image_path)

            # Get the image features.
            model, preprocess = ClipProcessor.load_model()
            ClipProcessor.get_image_features(image_path, model, preprocess)

            # Reaching this point without an exception means the test passes;
            # the temporary directory is cleaned up automatically.
            self.assertTrue(True)

    def test_get_text_features(self):
        # Test that `get_text_features()` returns a list containing the text embedding.
        query = "This is a text query."
        model, preprocess = ClipProcessor.load_model()
        text_features = ClipProcessor.get_text_features(query)

        # Assert that the text embedding is not None.
        self.assertIsNotNone(text_features)

        # Assert that the text embedding is a list.
        self.assertIsInstance(text_features, list)

        # Assert that the text embedding has the expected length of 512.
        self.assertEqual(len(text_features), 512)
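Note (not part of the diff): a rough sketch of what a get_text_features helper could look like, assuming OpenAI's clip package and the ViT-B/32 checkpoint, whose text embeddings are 512-dimensional as asserted above; the real ClipProcessor may differ.

import clip
import torch


def get_text_features_sketch(query: str) -> list:
    """Illustrative only: encode a text query into a 512-dim embedding list."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _preprocess = clip.load("ViT-B/32", device=device)
    tokens = clip.tokenize([query]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    # Flatten the (1, 512) tensor into a plain Python list of floats.
    return text_features.squeeze(0).cpu().tolist()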

View File

@@ -186,6 +186,34 @@ class TestChromaDbCollection(unittest.TestCase):
        # Should still be 1, not 2.
        self.assertEqual(app.db.count(), 1)

    def test_add_with_skip_embedding(self):
        """
        Test that add() with skip_embedding=True stores the provided embeddings
        and that get() and query() then behave as expected.
        """
        # Start with a clean app
        self.app_with_settings.reset()

        # Collection should be empty when created
        self.assertEqual(self.app_with_settings.db.count(), 0)

        self.app_with_settings.db.add(
            embeddings=[[0, 0, 0]],
            documents=["document"],
            metadatas=[{"value": "somevalue"}],
            ids=["id"],
            skip_embedding=True,
        )
        # After adding, should contain one item
        self.assertEqual(self.app_with_settings.db.count(), 1)

        # Validate that the get utility of the database works as expected
        data = self.app_with_settings.db.get(["id"], limit=1)
        expected_value = {
            "documents": ["document"],
            "embeddings": None,
            "ids": ["id"],
            "metadatas": [{"value": "somevalue"}],
        }
        self.assertEqual(data, expected_value)

        # Validate that the query utility of the database works as expected
        data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
        expected_value = ["document"]
        self.assertEqual(data, expected_value)

    def test_collections_are_persistent(self):
        """
        Test that a collection can be picked up later.
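Note (not part of the diff): the skip_embedding flag exercised above lets callers supply precomputed vectors instead of having the store run its embedding function, and query() then accepts a raw vector as input_query. A minimal sketch of that branching around the chromadb collection API; this is illustrative, not the actual embedchain ChromaDB wrapper.

from typing import Any, Dict, List, Optional


class VectorDBSketch:
    """Illustrative wrapper showing only the skip_embedding branch."""

    def __init__(self, collection: Any, embedding_fn=None):
        self.collection = collection
        self.embedding_fn = embedding_fn

    def add(self, embeddings: Optional[List[List[float]]], documents: List[str],
            metadatas: List[Dict], ids: List[str], skip_embedding: bool = False) -> None:
        if not skip_embedding:
            # Compute vectors from the documents with the configured embedder.
            embeddings = self.embedding_fn(documents)
        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)

    def query(self, input_query, where: Dict, n_results: int, skip_embedding: bool = False) -> List[str]:
        # With skip_embedding=True, input_query is already a vector; otherwise embed the text first.
        query_embedding = input_query if skip_embedding else self.embedding_fn([input_query])[0]
        result = self.collection.query(query_embeddings=[query_embedding], where=where, n_results=n_results)
        return result["documents"][0]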

View File

@@ -1,14 +1,109 @@
import os
import unittest
from unittest.mock import patch

from embedchain import App
from embedchain.config import AppConfig, ElasticsearchDBConfig
from embedchain.embedder.gpt4all import GPT4AllEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB


class TestEsDB(unittest.TestCase):
    def setUp(self):
        self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")

    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
    def test_setUp(self, mock_client):
        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
        self.vector_dim = 384
        app_config = AppConfig(collection_name=False, collect_metrics=False)
        self.app = App(config=app_config, db=self.db)

        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
        self.assertEqual(self.db.client, mock_client.return_value)

    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
    def test_query(self, mock_client):
        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
        app_config = AppConfig(collection_name=False, collect_metrics=False)
        self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())

        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
        self.assertEqual(self.db.client, mock_client.return_value)

        # Create some dummy data.
        embeddings = [[1, 2, 3], [4, 5, 6]]
        documents = ["This is a document.", "This is another document."]
        metadatas = [{}, {}]
        ids = ["doc_1", "doc_2"]

        # Add the data to the database.
        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)

        search_response = {
            "hits": {
                "hits": [
                    {"_source": {"text": "This is a document."}, "_score": 0.9},
                    {"_source": {"text": "This is another document."}, "_score": 0.8},
                ]
            }
        }

        # Configure the mock client to return the mocked response.
        mock_client.return_value.search.return_value = search_response

        # Query the database for the documents most similar to the query "This is a document".
        query = ["This is a document"]
        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)

        # Assert that the results are correct.
        self.assertEqual(results, ["This is a document.", "This is another document."])

    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
    def test_query_with_skip_embedding(self, mock_client):
        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
        app_config = AppConfig(collection_name=False, collect_metrics=False)
        self.app = App(config=app_config, db=self.db)

        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
        self.assertEqual(self.db.client, mock_client.return_value)

        # Create some dummy data.
        embeddings = [[1, 2, 3], [4, 5, 6]]
        documents = ["This is a document.", "This is another document."]
        metadatas = [{}, {}]
        ids = ["doc_1", "doc_2"]

        # Add the data to the database.
        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)

        search_response = {
            "hits": {
                "hits": [
                    {"_source": {"text": "This is a document."}, "_score": 0.9},
                    {"_source": {"text": "This is another document."}, "_score": 0.8},
                ]
            }
        }

        # Configure the mock client to return the mocked response.
        mock_client.return_value.search.return_value = search_response

        # Query the database for the documents most similar to the query "This is a document".
        query = ["This is a document"]
        results = self.db.query(query, n_results=2, where={}, skip_embedding=True)

        # Assert that the results are correct.
        self.assertEqual(results, ["This is a document.", "This is another document."])

    def test_init_without_url(self):
        # Make sure it's not loaded from env
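Note (not part of the diff): the mocked search_response above implies that ElasticsearchDB.query reads document text out of hits.hits[*]._source["text"]. A minimal sketch of that parsing step; the real query construction (e.g. script_score vs. kNN) is not shown in this hunk.

from typing import Any, Dict, List


def parse_search_hits_sketch(search_response: Dict[str, Any]) -> List[str]:
    """Illustrative only: pull document text out of an Elasticsearch search response."""
    hits = search_response.get("hits", {}).get("hits", [])
    return [hit["_source"]["text"] for hit in hits]


# Applied to the mocked response used in the tests above, this yields:
# ["This is a document.", "This is another document."]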