Add GPT4Vision Image loader (#1089)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Sidharth Mohanty
2024-01-02 03:57:23 +05:30
committed by GitHub
parent 367d6b70e2
commit c62663f2e4
29 changed files with 291 additions and 714 deletions

View File

@@ -1,78 +0,0 @@
import unittest
from embedchain.chunkers.images import ImagesChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType
class TestImageChunker(unittest.TestCase):
def test_chunks(self):
"""
Test the chunks generated by TextChunker.
# TODO: Not a very precise test.
"""
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = ImagesChunker(config=chunker_config)
# Data type must be set manually in the test
chunker.set_data_type(DataType.IMAGES)
image_path = "./tmp/image.jpeg"
app_id = "app1"
result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
expected_chunks = {
"doc_id": f"{app_id}--123",
"documents": [image_path],
"embeddings": ["embedding"],
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
"metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
}
self.assertEqual(expected_chunks, result)
def test_chunks_with_default_config(self):
"""
Test the chunks generated by ImageChunker with default config.
"""
chunker = ImagesChunker()
# Data type must be set manually in the test
chunker.set_data_type(DataType.IMAGES)
image_path = "./tmp/image.jpeg"
app_id = "app1"
result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
expected_chunks = {
"doc_id": f"{app_id}--123",
"documents": [image_path],
"embeddings": ["embedding"],
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
"metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
}
self.assertEqual(expected_chunks, result)
def test_word_count(self):
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = ImagesChunker(config=chunker_config)
chunker.set_data_type(DataType.IMAGES)
document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
result = chunker.get_word_count(document)
self.assertEqual(result, 1)
class MockLoader:
def load_data(self, src):
"""
Mock loader that returns a list of data dictionaries.
Adjust this method to return different data for testing.
"""
return {
"doc_id": "123",
"data": [
{
"content": src,
"embedding": "embedding",
"meta_data": {"url": "none"},
}
],
}

View File

@@ -1,44 +0,0 @@
import os
import tempfile
import urllib
from PIL import Image
from embedchain.models.clip_processor import ClipProcessor
class TestClipProcessor:
def test_load_model(self):
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
model = ClipProcessor.load_model()
assert model is not None
def test_get_image_features(self):
# Clone the image to a temporary folder.
with tempfile.TemporaryDirectory() as tmp_dir:
urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
image = Image.open("image.jpg")
image.save(os.path.join(tmp_dir, "image.jpg"))
# Get the image features.
model = ClipProcessor.load_model()
ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model)
# Delete the temporary file.
os.remove(os.path.join(tmp_dir, "image.jpg"))
os.remove("image.jpg")
def test_get_text_features(self):
# Test that the `get_text_features()` method returns a list containing the text embedding.
query = "This is a text query."
text_features = ClipProcessor.get_text_features(query)
# Assert that the text embedding is not None.
assert text_features is not None
# Assert that the text embedding is a list of floats.
assert isinstance(text_features, list)
# Assert that the text embedding has the correct length.
assert len(text_features) == 512

View File

@@ -148,73 +148,6 @@ def test_chroma_db_collection_changes_encapsulated():
app.db.reset()
def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
# Start with a clean app
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document"],
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
ids=["id"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 1
data = app_with_settings.db.get(["id"], limit=1)
expected_value = {
"documents": ["document"],
"embeddings": None,
"ids": ["id"],
"metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
"data": None,
"uris": None,
}
assert data == expected_value
data_without_citations = app_with_settings.db.query(
input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
)
expected_value_without_citations = ["document"]
assert data_without_citations == expected_value_without_citations
app_with_settings.db.reset()
def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
# Start with a clean app
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
with pytest.raises(ValueError):
app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 0
with pytest.raises(ValueError):
app_with_settings.db.add(
embeddings=None,
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 0
app_with_settings.db.reset()
def test_chroma_db_collection_collections_are_persistent():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
@@ -312,60 +245,3 @@ def test_chroma_db_collection_reset():
app2.db.reset()
app3.db.reset()
app4.db.reset()
def test_chroma_db_collection_query(app_with_settings):
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document"],
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
ids=["id"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 1
app_with_settings.db.add(
embeddings=[[0, 1, 0]],
documents=["document2"],
metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
ids=["id2"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 2
data_without_citations = app_with_settings.db.query(
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
)
expected_value_without_citations = ["document", "document2"]
assert data_without_citations == expected_value_without_citations
data_with_citations = app_with_settings.db.query(
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
)
expected_value_with_citations = [
(
"document",
{
"url": "url_1",
"doc_id": "doc_id_1",
"score": 0.0,
},
),
(
"document2",
{
"url": "url_2",
"doc_id": "doc_id_2",
"score": 1.0,
},
),
]
assert data_with_citations == expected_value_with_citations
app_with_settings.db.reset()

View File

@@ -35,7 +35,7 @@ class TestEsDB(unittest.TestCase):
ids = ["doc_1", "doc_2"]
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
self.db.add(embeddings, documents, metadatas, ids)
search_response = {
"hits": {
@@ -60,63 +60,17 @@ class TestEsDB(unittest.TestCase):
# Query the database for the documents that are most similar to the query "This is a document".
query = ["This is a document"]
results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
results_without_citations = self.db.query(query, n_results=2, where={})
expected_results_without_citations = ["This is a document.", "This is another document."]
self.assertEqual(results_without_citations, expected_results_without_citations)
results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
expected_results_with_citations = [
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
]
self.assertEqual(results_with_citations, expected_results_with_citations)
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query_with_skip_embedding(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
# Create some dummy data.
embeddings = [[1, 2, 3], [4, 5, 6]]
documents = ["This is a document.", "This is another document."]
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
ids = ["doc_1", "doc_2"]
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
search_response = {
"hits": {
"hits": [
{
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
"_score": 0.9,
},
{
"_source": {
"text": "This is another document.",
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
},
"_score": 0.8,
},
]
}
}
# Configure the mock client to return the mocked response.
mock_client.return_value.search.return_value = search_response
# Query the database for the documents that are most similar to the query "This is a document".
query = ["This is a document"]
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
# Assert that the results are correct.
self.assertEqual(results, ["This is a document.", "This is another document."])
def test_init_without_url(self):
# Make sure it's not loaded from env
try:

View File

@@ -54,7 +54,7 @@ class TestPinecone:
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc1", "doc2"]
db.add(vectors, documents, metadatas, ids, True)
db.add(vectors, documents, metadatas, ids)
expected_pinecone_upsert_args = [
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
@@ -81,7 +81,7 @@ class TestPinecone:
# Query the database for documents that are similar to "document"
input_query = ["document"]
n_results = 1
db.query(input_query, n_results, where={}, skip_embedding=False)
db.query(input_query, n_results, where={})
# Assert that the Pinecone client was called to query the database
pinecone_client_mock.query.assert_called_once_with(

View File

@@ -12,6 +12,11 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.qdrant import QdrantDB
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
"""A mock embedding function."""
return [[1, 2, 3], [4, 5, 6]]
class TestQdrantDB(unittest.TestCase):
TEST_UUIDS = ["abc", "def", "ghi"]
@@ -25,6 +30,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -42,6 +48,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -61,6 +68,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -71,8 +79,7 @@ class TestQdrantDB(unittest.TestCase):
documents = ["This is a test document.", "This is another test document."]
metadatas = [{}, {}]
ids = ["123", "456"]
skip_embedding = True
db.add(embeddings, documents, metadatas, ids, skip_embedding)
db.add(embeddings, documents, metadatas, ids)
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
points=Batch(
@@ -98,6 +105,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -105,7 +113,7 @@ class TestQdrantDB(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
qdrant_client_mock.return_value.search.assert_called_once_with(
collection_name="embedchain-store-1526",
@@ -119,7 +127,7 @@ class TestQdrantDB(unittest.TestCase):
)
]
),
query_vector=["This is a test document."],
query_vector=[1, 2, 3],
limit=1,
)
@@ -128,6 +136,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -142,6 +151,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()

View File

@@ -8,6 +8,11 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.weaviate import WeaviateDB
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
"""A mock embedding function."""
return [[1, 2, 3], [4, 5, 6]]
class TestWeaviateDb(unittest.TestCase):
def test_incorrect_config_throws_error(self):
"""Test the init method of the WeaviateDb class throws error for incorrect config"""
@@ -25,6 +30,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -92,6 +98,7 @@ class TestWeaviateDb(unittest.TestCase):
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -111,6 +118,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -122,8 +130,7 @@ class TestWeaviateDb(unittest.TestCase):
documents = ["This is a test document.", "This is another test document."]
metadatas = [None, None]
ids = ["123", "456"]
skip_embedding = True
db.add(embeddings, documents, metadatas, ids, skip_embedding)
db.add(embeddings, documents, metadatas, ids)
# Check if the document was added to the database.
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
@@ -155,6 +162,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -162,12 +170,10 @@ class TestWeaviateDb(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
db.query(input_query=["This is a test document."], n_results=1, where={})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
{"vector": ["This is a test document."]}
)
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
def test_query_with_where(self, weaviate_mock):
@@ -180,6 +186,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -187,15 +194,13 @@ class TestWeaviateDb(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
weaviate_client_query_get_mock.with_where.assert_called_once_with(
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
)
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
{"vector": ["This is a test document."]}
)
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
def test_reset(self, weaviate_mock):
@@ -206,6 +211,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -228,6 +234,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()

View File

@@ -108,65 +108,7 @@ class TestZillizDBCollection:
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
"""
Test if the `ZillizVectorDB` instance is takes in the query with skip_embeddings.
"""
# Create an instance of ZillizVectorDB with mock config
zilliz_db = ZillizVectorDB(config=mock_config)
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
assert zilliz_db.client == mock_client()
# Mock the MilvusClient search method
with patch.object(zilliz_db.client, "search") as mock_search:
# Mock the search result
mock_search.return_value = [
[
{
"distance": 0.5,
"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
}
]
]
# Call the query method with skip_embedding=True
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
# Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_with(
collection_name=mock_config.collection_name,
data=["query_text"],
limit=1,
output_fields=["*"],
)
# Assert that the query result matches the expected result
assert query_result == ["result_doc"]
query_result_with_citations = zilliz_db.query(
input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
)
mock_search.assert_called_with(
collection_name=mock_config.collection_name,
data=["query_text"],
limit=1,
output_fields=["*"],
)
assert query_result_with_citations == [
("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
]
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
"""
Test if the `ZillizVectorDB` instance is takes in the query without skip_embeddings.
"""
def test_query(self, mock_connect, mock_client, mock_embedder, mock_config):
# Create an instance of ZillizVectorDB with mock config
zilliz_db = ZillizVectorDB(config=mock_config)
@@ -193,8 +135,7 @@ class TestZillizDBCollection:
]
]
# Call the query method with skip_embedding=False
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
# Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_with(
@@ -208,7 +149,7 @@ class TestZillizDBCollection:
assert query_result == ["result_doc"]
query_result_with_citations = zilliz_db.query(
input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
input_query=["query_text"], n_results=1, where={}, citations=True
)
mock_search.assert_called_with(