[bugfix] Fix issue when llm config is not defined (#763)

This commit is contained in:
Deshraj Yadav
2023-10-04 12:08:21 -07:00
committed by GitHub
parent d0af018b8d
commit 87d0b5c76f
15 changed files with 100 additions and 88 deletions

View File

@@ -19,11 +19,13 @@ class TestImageChunker(unittest.TestCase):
image_path = "./tmp/image.jpeg"
result = chunker.create_chunks(MockLoader(), image_path)
expected_chunks = {'doc_id': '123',
'documents': [image_path],
'embeddings': ['embedding'],
'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
expected_chunks = {
"doc_id": "123",
"documents": [image_path],
"embeddings": ["embedding"],
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
"metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
}
self.assertEqual(expected_chunks, result)
def test_chunks_with_default_config(self):
@@ -37,11 +39,13 @@ class TestImageChunker(unittest.TestCase):
image_path = "./tmp/image.jpeg"
result = chunker.create_chunks(MockLoader(), image_path)
expected_chunks = {'doc_id': '123',
'documents': [image_path],
'embeddings': ['embedding'],
'ids': ['140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe'],
'metadatas': [{'data_type': 'images', 'doc_id': '123', 'url': 'none'}]}
expected_chunks = {
"doc_id": "123",
"documents": [image_path],
"embeddings": ["embedding"],
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
"metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
}
self.assertEqual(expected_chunks, result)
def test_word_count(self):

View File

@@ -1,29 +1,23 @@
import tempfile
import unittest
import os
import tempfile
import urllib
from PIL import Image
from embedchain.models.clip_processor import ClipProcessor
class ClipProcessorTest(unittest.TestCase):
class TestClipProcessor:
def test_load_model(self):
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
model, preprocess = ClipProcessor.load_model()
# Assert that the model is not None.
self.assertIsNotNone(model)
# Assert that the preprocess is not None.
self.assertIsNotNone(preprocess)
assert model is not None
assert preprocess is not None
def test_get_image_features(self):
# Download the example image and save a copy of it in a temporary folder.
with tempfile.TemporaryDirectory() as tmp_dir:
urllib.request.urlretrieve(
'https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg',
"image.jpg")
urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
image = Image.open("image.jpg")
image.save(os.path.join(tmp_dir, "image.jpg"))
@@ -35,9 +29,6 @@ class ClipProcessorTest(unittest.TestCase):
# Delete the temporary file.
os.remove(os.path.join(tmp_dir, "image.jpg"))
# Assert that the test passes.
self.assertTrue(True)
def test_get_text_features(self):
# Test that the `get_text_features()` method returns a list containing the text embedding.
query = "This is a text query."
@@ -46,10 +37,10 @@ class ClipProcessorTest(unittest.TestCase):
text_features = ClipProcessor.get_text_features(query)
# Assert that the text embedding is not None.
self.assertIsNotNone(text_features)
assert text_features is not None
# Assert that the text embedding is a list (element types are not checked here).
self.assertIsInstance(text_features, list)
assert isinstance(text_features, list)
# Assert that the text embedding has the correct length.
self.assertEqual(len(text_features), 512)
assert len(text_features) == 512

View File

@@ -197,21 +197,29 @@ class TestChromaDbCollection(unittest.TestCase):
# Collection should be empty when created
self.assertEqual(self.app_with_settings.db.count(), 0)
self.app_with_settings.db.add(embeddings=[[0, 0, 0]], documents=["document"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True)
self.app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
# After adding, should contain one item
self.assertEqual(self.app_with_settings.db.count(), 1)
# Validate if the get utility of the database is working as expected
data = self.app_with_settings.db.get(["id"], limit=1)
expected_value = {'documents': ['document'],
'embeddings': None,
'ids': ['id'],
'metadatas': [{'value': 'somevalue'}]}
expected_value = {
"documents": ["document"],
"embeddings": None,
"ids": ["id"],
"metadatas": [{"value": "somevalue"}],
}
self.assertEqual(data, expected_value)
# Validate if the query utility of the database is working as expected
data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
expected_value = ['document']
expected_value = ["document"]
self.assertEqual(data, expected_value)
def test_collections_are_persistent(self):

View File

@@ -4,11 +4,11 @@ from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig, ElasticsearchDBConfig
from embedchain.vectordb.elasticsearch import ElasticsearchDB
from embedchain.embedder.gpt4all import GPT4AllEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB
class TestEsDB(unittest.TestCase):
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_setUp(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
@@ -37,17 +37,11 @@ class TestEsDB(unittest.TestCase):
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
search_response = {"hits":
{"hits":
[
{
"_source": {"text": "This is a document."},
"_score": 0.9
},
{
"_source": {"text": "This is another document."},
"_score": 0.8
}
search_response = {
"hits": {
"hits": [
{"_source": {"text": "This is a document."}, "_score": 0.9},
{"_source": {"text": "This is another document."}, "_score": 0.8},
]
}
}
@@ -80,17 +74,11 @@ class TestEsDB(unittest.TestCase):
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
search_response = {"hits":
{"hits":
[
{
"_source": {"text": "This is a document."},
"_score": 0.9
},
{
"_source": {"text": "This is another document."},
"_score": 0.8
}
search_response = {
"hits": {
"hits": [
{"_source": {"text": "This is a document."}, "_score": 0.9},
{"_source": {"text": "This is another document."}, "_score": 0.8},
]
}
}