Add support for image dataset (#571)
Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
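For context, a minimal usage sketch of what this change enables. The `add()` call and the "images" data_type string are assumptions inferred from the tests in this commit, not part of the diff:

```python
from embedchain import App

# Hypothetical usage (not part of this diff): index a local image and query it.
# The "images" data type matches the metadata written by ImagesChunker in the
# tests below; the default App may still require its usual API keys.
app = App()
app.add("./tmp/image.jpeg", data_type="images")
answer = app.query("a photo of an example image")
```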
72
tests/chunkers/test_image_chunker.py
Normal file
@@ -0,0 +1,72 @@
import unittest

from embedchain.chunkers.images import ImagesChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType


class TestImageChunker(unittest.TestCase):
    def test_chunks(self):
        """
        Test the chunks generated by ImagesChunker.
        # TODO: Not a very precise test.
        """
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
        chunker = ImagesChunker(config=chunker_config)
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.IMAGES)

        image_path = "./tmp/image.jpeg"
        result = chunker.create_chunks(MockLoader(), image_path)

        expected_chunks = {
            "doc_id": "123",
            "documents": [image_path],
            "embeddings": ["embedding"],
            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
        }
        self.assertEqual(expected_chunks, result)

    def test_chunks_with_default_config(self):
        """
        Test the chunks generated by ImagesChunker with the default config.
        """
        chunker = ImagesChunker()
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.IMAGES)

        image_path = "./tmp/image.jpeg"
        result = chunker.create_chunks(MockLoader(), image_path)

        expected_chunks = {
            "doc_id": "123",
            "documents": [image_path],
            "embeddings": ["embedding"],
            "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
        }
        self.assertEqual(expected_chunks, result)

    def test_word_count(self):
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
        chunker = ImagesChunker(config=chunker_config)
        chunker.set_data_type(DataType.IMAGES)

        document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
        result = chunker.get_word_count(document)
        self.assertEqual(result, 1)


class MockLoader:
    def load_data(self, src):
        """
        Mock loader that returns a list of data dictionaries.
        Adjust this method to return different data for testing.
        """
        return {
            "doc_id": "123",
            "data": [
                {
                    "content": src,
                    "embedding": "embedding",
                    "meta_data": {"url": "none"},
                }
            ],
        }
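To make the `expected_chunks` shape easier to read, here is a hypothetical sketch of how an ImagesChunker-style `create_chunks` could assemble its output from a loader's `load_data` result. The sha256-based id and the exact field handling are assumptions for illustration, not the library's actual implementation:

```python
import hashlib


def create_image_chunks(loader, src, data_type="images"):
    # Hypothetical sketch: mirrors the dict shape asserted in the test above.
    loaded = loader.load_data(src)
    documents, embeddings, ids, metadatas = [], [], [], []
    for item in loaded["data"]:
        documents.append(item["content"])
        embeddings.append(item["embedding"])
        # Assumption: ids are a content hash; the real id scheme may differ.
        ids.append(hashlib.sha256(item["content"].encode()).hexdigest())
        meta = dict(item["meta_data"])
        meta.update({"data_type": data_type, "doc_id": loaded["doc_id"]})
        metadatas.append(meta)
    return {
        "doc_id": loaded["doc_id"],
        "documents": documents,
        "embeddings": embeddings,
        "ids": ids,
        "metadatas": metadatas,
    }


# e.g. create_image_chunks(MockLoader(), "./tmp/image.jpeg")
```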
@@ -62,6 +62,15 @@ class TestTextChunker(unittest.TestCase):
        self.assertEqual(len(documents), len(text))

    def test_word_count(self):
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
        chunker = TextChunker(config=chunker_config)
        chunker.set_data_type(DataType.TEXT)

        document = ["ab cd", "ef gh"]
        result = chunker.get_word_count(document)
        self.assertEqual(result, 4)


class MockLoader:
    def load_data(self, src):
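The two `get_word_count` tests imply different counting behaviour for text and images. A small illustrative sketch; these helpers are hypothetical, not the chunkers' actual code:

```python
def text_word_count(documents):
    # TextChunker-style: total whitespace-separated tokens across all documents.
    return sum(len(doc.split()) for doc in documents)


def image_word_count(documents):
    # ImagesChunker-style: images carry no words, so the count is a constant 1.
    return 1


assert text_word_count(["ab cd", "ef gh"]) == 4
assert image_word_count([["ab cd", "ef gh"], ["ij kl", "mn op"]]) == 1
```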
BIN
tests/models/image.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 27 KiB
55
tests/models/test_clip_processor.py
Normal file
@@ -0,0 +1,55 @@
import os
import tempfile
import unittest
import urllib.request

from PIL import Image

from embedchain.models.clip_processor import ClipProcessor


class ClipProcessorTest(unittest.TestCase):
    def test_load_model(self):
        # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
        model, preprocess = ClipProcessor.load_model()

        # Assert that the model is not None.
        self.assertIsNotNone(model)

        # Assert that the preprocess is not None.
        self.assertIsNotNone(preprocess)

    def test_get_image_features(self):
        # Download the example image and copy it into a temporary folder.
        with tempfile.TemporaryDirectory() as tmp_dir:
            urllib.request.urlretrieve(
                "https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg",
                "image.jpg",
            )

            image = Image.open("image.jpg")
            image.save(os.path.join(tmp_dir, "image.jpg"))

            # Get the image features.
            model, preprocess = ClipProcessor.load_model()
            ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model, preprocess)

            # Delete the temporary file.
            os.remove(os.path.join(tmp_dir, "image.jpg"))

            # Assert that the test passes.
            self.assertTrue(True)

    def test_get_text_features(self):
        # Test that the `get_text_features()` method returns a list containing the text embedding.
        query = "This is a text query."
        model, preprocess = ClipProcessor.load_model()

        text_features = ClipProcessor.get_text_features(query)

        # Assert that the text embedding is not None.
        self.assertIsNotNone(text_features)

        # Assert that the text embedding is a list of floats.
        self.assertIsInstance(text_features, list)

        # Assert that the text embedding has the correct length.
        self.assertEqual(len(text_features), 512)
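Since `get_text_features()` returns a 512-dimensional list of floats, text and image embeddings can be compared with cosine similarity. A minimal sketch, assuming `get_image_features()` yields a vector of the same shape (the tests above don't assert its return type):

```python
import math


def cosine_similarity(a, b):
    # Plain-Python cosine similarity between two equal-length float lists.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


# Hypothetical usage with the ClipProcessor API exercised in the tests:
# model, preprocess = ClipProcessor.load_model()
# image_vec = ClipProcessor.get_image_features("image.jpg", model, preprocess)
# text_vec = ClipProcessor.get_text_features("a photo of an example image")
# score = cosine_similarity(image_vec, text_vec)
```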
@@ -186,6 +186,34 @@ class TestChromaDbCollection(unittest.TestCase):
        # Should still be 1, not 2.
        self.assertEqual(app.db.count(), 1)

    def test_add_with_skip_embedding(self):
        """
        Test that data added with skip_embedding=True is stored and can be fetched and queried.
        """
        # Start with a clean app
        self.app_with_settings.reset()
        # app = App(config=AppConfig(collect_metrics=False), db=db)

        # Collection should be empty when created
        self.assertEqual(self.app_with_settings.db.count(), 0)

        self.app_with_settings.db.add(
            embeddings=[[0, 0, 0]],
            documents=["document"],
            metadatas=[{"value": "somevalue"}],
            ids=["id"],
            skip_embedding=True,
        )
        # After adding, it should contain one item
        self.assertEqual(self.app_with_settings.db.count(), 1)

        # Validate that the get utility of the database works as expected
        data = self.app_with_settings.db.get(["id"], limit=1)
        expected_value = {
            "documents": ["document"],
            "embeddings": None,
            "ids": ["id"],
            "metadatas": [{"value": "somevalue"}],
        }
        self.assertEqual(data, expected_value)

        # Validate that the query utility of the database works as expected
        data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
        expected_value = ["document"]
        self.assertEqual(data, expected_value)

    def test_collections_are_persistent(self):
        """
        Test that a collection can be picked up later.
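The skip_embedding flag exercised above lets a caller supply precomputed vectors (for example CLIP image embeddings) so the vector database stores and queries them as-is instead of embedding the inputs itself. A hypothetical end-to-end sketch reusing the `add()`/`query()` signatures from this test; the App construction is an assumption and may require the usual embedder credentials:

```python
from embedchain import App
from embedchain.config import AppConfig

# Hypothetical flow (not part of this diff): store a precomputed image
# embedding and query with another precomputed vector.
app = App(config=AppConfig(collect_metrics=False))
db = app.db  # the app's default vector database

db.add(
    embeddings=[[0.1, 0.2, 0.3]],         # precomputed, e.g. from ClipProcessor
    documents=["./tmp/image.jpeg"],
    metadatas=[{"data_type": "images"}],
    ids=["image-1"],
    skip_embedding=True,                   # tell the DB not to re-embed
)
results = db.query(input_query=[0.1, 0.2, 0.3], where={}, n_results=1, skip_embedding=True)
```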
@@ -1,14 +1,109 @@
import os
import unittest
from unittest.mock import patch

from embedchain import App
from embedchain.config import AppConfig, ElasticsearchDBConfig
from embedchain.embedder.gpt4all import GPT4AllEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB


class TestEsDB(unittest.TestCase):
    def setUp(self):
        self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")

    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
    def test_setUp(self, mock_client):
        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
        self.vector_dim = 384
        app_config = AppConfig(collection_name=False, collect_metrics=False)
        self.app = App(config=app_config, db=self.db)

        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
        self.assertEqual(self.db.client, mock_client.return_value)

    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
    def test_query(self, mock_client):
        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
        app_config = AppConfig(collection_name=False, collect_metrics=False)
        self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())

        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
        self.assertEqual(self.db.client, mock_client.return_value)

        # Create some dummy data.
        embeddings = [[1, 2, 3], [4, 5, 6]]
        documents = ["This is a document.", "This is another document."]
        metadatas = [{}, {}]
        ids = ["doc_1", "doc_2"]

        # Add the data to the database.
        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)

        search_response = {
            "hits": {
                "hits": [
                    {
                        "_source": {"text": "This is a document."},
                        "_score": 0.9,
                    },
                    {
                        "_source": {"text": "This is another document."},
                        "_score": 0.8,
                    },
                ]
            }
        }

        # Configure the mock client to return the mocked response.
        mock_client.return_value.search.return_value = search_response

        # Query the database for the documents that are most similar to the query "This is a document".
        query = ["This is a document"]
        results = self.db.query(query, n_results=2, where={}, skip_embedding=False)

        # Assert that the results are correct.
        self.assertEqual(results, ["This is a document.", "This is another document."])

    @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
    def test_query_with_skip_embedding(self, mock_client):
        self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
        app_config = AppConfig(collection_name=False, collect_metrics=False)
        self.app = App(config=app_config, db=self.db)

        # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
        self.assertEqual(self.db.client, mock_client.return_value)

        # Create some dummy data.
        embeddings = [[1, 2, 3], [4, 5, 6]]
        documents = ["This is a document.", "This is another document."]
        metadatas = [{}, {}]
        ids = ["doc_1", "doc_2"]

        # Add the data to the database.
        self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)

        search_response = {
            "hits": {
                "hits": [
                    {
                        "_source": {"text": "This is a document."},
                        "_score": 0.9,
                    },
                    {
                        "_source": {"text": "This is another document."},
                        "_score": 0.8,
                    },
                ]
            }
        }

        # Configure the mock client to return the mocked response.
        mock_client.return_value.search.return_value = search_response

        # Query the database for the documents that are most similar to the query "This is a document".
        query = ["This is a document"]
        results = self.db.query(query, n_results=2, where={}, skip_embedding=True)

        # Assert that the results are correct.
        self.assertEqual(results, ["This is a document.", "This is another document."])

    def test_init_without_url(self):
        # Make sure it's not loaded from env
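The mocked search response and the expected result suggest that `ElasticsearchDB.query()` maps Elasticsearch hits to their stored text. A one-line sketch of that mapping; this is an assumption about the implementation, shown only to explain the mock's shape:

```python
# Given the mocked Elasticsearch response used above:
search_response = {
    "hits": {
        "hits": [
            {"_source": {"text": "This is a document."}, "_score": 0.9},
            {"_source": {"text": "This is another document."}, "_score": 0.8},
        ]
    }
}

# Hypothetical mapping from ES hits to the documents returned by query():
results = [hit["_source"]["text"] for hit in search_response["hits"]["hits"]]
assert results == ["This is a document.", "This is another document."]
```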