Add GPT4Vision Image loader (#1089)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -1,78 +0,0 @@
|
||||
import unittest
|
||||
|
||||
from embedchain.chunkers.images import ImagesChunker
|
||||
from embedchain.config import ChunkerConfig
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class TestImageChunker(unittest.TestCase):
|
||||
def test_chunks(self):
|
||||
"""
|
||||
Test the chunks generated by TextChunker.
|
||||
# TODO: Not a very precise test.
|
||||
"""
|
||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||
chunker = ImagesChunker(config=chunker_config)
|
||||
# Data type must be set manually in the test
|
||||
chunker.set_data_type(DataType.IMAGES)
|
||||
|
||||
image_path = "./tmp/image.jpeg"
|
||||
app_id = "app1"
|
||||
result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
|
||||
|
||||
expected_chunks = {
|
||||
"doc_id": f"{app_id}--123",
|
||||
"documents": [image_path],
|
||||
"embeddings": ["embedding"],
|
||||
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
|
||||
"metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
|
||||
}
|
||||
self.assertEqual(expected_chunks, result)
|
||||
|
||||
def test_chunks_with_default_config(self):
|
||||
"""
|
||||
Test the chunks generated by ImageChunker with default config.
|
||||
"""
|
||||
chunker = ImagesChunker()
|
||||
# Data type must be set manually in the test
|
||||
chunker.set_data_type(DataType.IMAGES)
|
||||
|
||||
image_path = "./tmp/image.jpeg"
|
||||
app_id = "app1"
|
||||
result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
|
||||
|
||||
expected_chunks = {
|
||||
"doc_id": f"{app_id}--123",
|
||||
"documents": [image_path],
|
||||
"embeddings": ["embedding"],
|
||||
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
|
||||
"metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
|
||||
}
|
||||
self.assertEqual(expected_chunks, result)
|
||||
|
||||
def test_word_count(self):
|
||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||
chunker = ImagesChunker(config=chunker_config)
|
||||
chunker.set_data_type(DataType.IMAGES)
|
||||
|
||||
document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
|
||||
result = chunker.get_word_count(document)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
|
||||
class MockLoader:
|
||||
def load_data(self, src):
|
||||
"""
|
||||
Mock loader that returns a list of data dictionaries.
|
||||
Adjust this method to return different data for testing.
|
||||
"""
|
||||
return {
|
||||
"doc_id": "123",
|
||||
"data": [
|
||||
{
|
||||
"content": src,
|
||||
"embedding": "embedding",
|
||||
"meta_data": {"url": "none"},
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
import urllib
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from embedchain.models.clip_processor import ClipProcessor
|
||||
|
||||
|
||||
class TestClipProcessor:
|
||||
def test_load_model(self):
|
||||
# Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
|
||||
model = ClipProcessor.load_model()
|
||||
assert model is not None
|
||||
|
||||
def test_get_image_features(self):
|
||||
# Clone the image to a temporary folder.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
|
||||
|
||||
image = Image.open("image.jpg")
|
||||
image.save(os.path.join(tmp_dir, "image.jpg"))
|
||||
|
||||
# Get the image features.
|
||||
model = ClipProcessor.load_model()
|
||||
ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model)
|
||||
|
||||
# Delete the temporary file.
|
||||
os.remove(os.path.join(tmp_dir, "image.jpg"))
|
||||
os.remove("image.jpg")
|
||||
|
||||
def test_get_text_features(self):
|
||||
# Test that the `get_text_features()` method returns a list containing the text embedding.
|
||||
query = "This is a text query."
|
||||
text_features = ClipProcessor.get_text_features(query)
|
||||
|
||||
# Assert that the text embedding is not None.
|
||||
assert text_features is not None
|
||||
|
||||
# Assert that the text embedding is a list of floats.
|
||||
assert isinstance(text_features, list)
|
||||
|
||||
# Assert that the text embedding has the correct length.
|
||||
assert len(text_features) == 512
|
||||
@@ -148,73 +148,6 @@ def test_chroma_db_collection_changes_encapsulated():
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
||||
# Start with a clean app
|
||||
app_with_settings.db.reset()
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document"],
|
||||
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 1
|
||||
|
||||
data = app_with_settings.db.get(["id"], limit=1)
|
||||
expected_value = {
|
||||
"documents": ["document"],
|
||||
"embeddings": None,
|
||||
"ids": ["id"],
|
||||
"metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
"data": None,
|
||||
"uris": None,
|
||||
}
|
||||
|
||||
assert data == expected_value
|
||||
|
||||
data_without_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
|
||||
)
|
||||
expected_value_without_citations = ["document"]
|
||||
assert data_without_citations == expected_value_without_citations
|
||||
|
||||
app_with_settings.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
|
||||
# Start with a clean app
|
||||
app_with_settings.db.reset()
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document", "document2"],
|
||||
metadatas=[{"value": "somevalue"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
app_with_settings.db.add(
|
||||
embeddings=None,
|
||||
documents=["document", "document2"],
|
||||
metadatas=[{"value": "somevalue"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
app_with_settings.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_collections_are_persistent():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
@@ -312,60 +245,3 @@ def test_chroma_db_collection_reset():
|
||||
app2.db.reset()
|
||||
app3.db.reset()
|
||||
app4.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_query(app_with_settings):
|
||||
app_with_settings.db.reset()
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document"],
|
||||
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 1
|
||||
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 1, 0]],
|
||||
documents=["document2"],
|
||||
metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
|
||||
ids=["id2"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 2
|
||||
|
||||
data_without_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
|
||||
)
|
||||
expected_value_without_citations = ["document", "document2"]
|
||||
assert data_without_citations == expected_value_without_citations
|
||||
|
||||
data_with_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
|
||||
)
|
||||
expected_value_with_citations = [
|
||||
(
|
||||
"document",
|
||||
{
|
||||
"url": "url_1",
|
||||
"doc_id": "doc_id_1",
|
||||
"score": 0.0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"document2",
|
||||
{
|
||||
"url": "url_2",
|
||||
"doc_id": "doc_id_2",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
]
|
||||
assert data_with_citations == expected_value_with_citations
|
||||
|
||||
app_with_settings.db.reset()
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestEsDB(unittest.TestCase):
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
|
||||
self.db.add(embeddings, documents, metadatas, ids)
|
||||
|
||||
search_response = {
|
||||
"hits": {
|
||||
@@ -60,63 +60,17 @@ class TestEsDB(unittest.TestCase):
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = ["This is a document"]
|
||||
results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||
results_without_citations = self.db.query(query, n_results=2, where={})
|
||||
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||
|
||||
results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
|
||||
results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
|
||||
expected_results_with_citations = [
|
||||
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
|
||||
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
|
||||
]
|
||||
self.assertEqual(results_with_citations, expected_results_with_citations)
|
||||
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query_with_skip_embedding(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
self.assertEqual(self.db.client, mock_client.return_value)
|
||||
|
||||
# Create some dummy data.
|
||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
|
||||
|
||||
search_response = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||
"_score": 0.9,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"text": "This is another document.",
|
||||
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||
},
|
||||
"_score": 0.8,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# Configure the mock client to return the mocked response.
|
||||
mock_client.return_value.search.return_value = search_response
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = ["This is a document"]
|
||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that the results are correct.
|
||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
||||
|
||||
def test_init_without_url(self):
|
||||
# Make sure it's not loaded from env
|
||||
try:
|
||||
|
||||
@@ -54,7 +54,7 @@ class TestPinecone:
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["doc1", "doc2"]
|
||||
db.add(vectors, documents, metadatas, ids, True)
|
||||
db.add(vectors, documents, metadatas, ids)
|
||||
|
||||
expected_pinecone_upsert_args = [
|
||||
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
||||
@@ -81,7 +81,7 @@ class TestPinecone:
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
n_results = 1
|
||||
db.query(input_query, n_results, where={}, skip_embedding=False)
|
||||
db.query(input_query, n_results, where={})
|
||||
|
||||
# Assert that the Pinecone client was called to query the database
|
||||
pinecone_client_mock.query.assert_called_once_with(
|
||||
|
||||
@@ -12,6 +12,11 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.qdrant import QdrantDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
|
||||
"""A mock embedding function."""
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
class TestQdrantDB(unittest.TestCase):
|
||||
TEST_UUIDS = ["abc", "def", "ghi"]
|
||||
|
||||
@@ -25,6 +30,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -42,6 +48,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -61,6 +68,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -71,8 +79,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["123", "456"]
|
||||
skip_embedding = True
|
||||
db.add(embeddings, documents, metadatas, ids, skip_embedding)
|
||||
db.add(embeddings, documents, metadatas, ids)
|
||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
points=Batch(
|
||||
@@ -98,6 +105,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -105,7 +113,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
|
||||
qdrant_client_mock.return_value.search.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
@@ -119,7 +127,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
)
|
||||
]
|
||||
),
|
||||
query_vector=["This is a test document."],
|
||||
query_vector=[1, 2, 3],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
@@ -128,6 +136,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -142,6 +151,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
|
||||
@@ -8,6 +8,11 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.weaviate import WeaviateDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
|
||||
"""A mock embedding function."""
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
class TestWeaviateDb(unittest.TestCase):
|
||||
def test_incorrect_config_throws_error(self):
|
||||
"""Test the init method of the WeaviateDb class throws error for incorrect config"""
|
||||
@@ -25,6 +30,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -92,6 +98,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -111,6 +118,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -122,8 +130,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [None, None]
|
||||
ids = ["123", "456"]
|
||||
skip_embedding = True
|
||||
db.add(embeddings, documents, metadatas, ids, skip_embedding)
|
||||
db.add(embeddings, documents, metadatas, ids)
|
||||
|
||||
# Check if the document was added to the database.
|
||||
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
|
||||
@@ -155,6 +162,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -162,12 +170,10 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
|
||||
{"vector": ["This is a test document."]}
|
||||
)
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
|
||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||
def test_query_with_where(self, weaviate_mock):
|
||||
@@ -180,6 +186,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -187,15 +194,13 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
||||
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
||||
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
|
||||
)
|
||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
|
||||
{"vector": ["This is a test document."]}
|
||||
)
|
||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
|
||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||
def test_reset(self, weaviate_mock):
|
||||
@@ -206,6 +211,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -228,6 +234,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
|
||||
@@ -108,65 +108,7 @@ class TestZillizDBCollection:
|
||||
|
||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||
def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
|
||||
"""
|
||||
Test if the `ZillizVectorDB` instance is takes in the query with skip_embeddings.
|
||||
"""
|
||||
# Create an instance of ZillizVectorDB with mock config
|
||||
zilliz_db = ZillizVectorDB(config=mock_config)
|
||||
|
||||
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
|
||||
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
|
||||
|
||||
assert zilliz_db.client == mock_client()
|
||||
|
||||
# Mock the MilvusClient search method
|
||||
with patch.object(zilliz_db.client, "search") as mock_search:
|
||||
# Mock the search result
|
||||
mock_search.return_value = [
|
||||
[
|
||||
{
|
||||
"distance": 0.5,
|
||||
"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
# Call the query method with skip_embedding=True
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
output_fields=["*"],
|
||||
)
|
||||
|
||||
# Assert that the query result matches the expected result
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
output_fields=["*"],
|
||||
)
|
||||
|
||||
assert query_result_with_citations == [
|
||||
("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
|
||||
]
|
||||
|
||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||
def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
|
||||
"""
|
||||
Test if the `ZillizVectorDB` instance is takes in the query without skip_embeddings.
|
||||
"""
|
||||
def test_query(self, mock_connect, mock_client, mock_embedder, mock_config):
|
||||
# Create an instance of ZillizVectorDB with mock config
|
||||
zilliz_db = ZillizVectorDB(config=mock_config)
|
||||
|
||||
@@ -193,8 +135,7 @@ class TestZillizDBCollection:
|
||||
]
|
||||
]
|
||||
|
||||
# Call the query method with skip_embedding=False
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_with(
|
||||
@@ -208,7 +149,7 @@ class TestZillizDBCollection:
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
|
||||
input_query=["query_text"], n_results=1, where={}, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
|
||||
Reference in New Issue
Block a user