Add GPT4Vision Image loader (#1089)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -148,73 +148,6 @@ def test_chroma_db_collection_changes_encapsulated():
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
||||
# Start with a clean app
|
||||
app_with_settings.db.reset()
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document"],
|
||||
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 1
|
||||
|
||||
data = app_with_settings.db.get(["id"], limit=1)
|
||||
expected_value = {
|
||||
"documents": ["document"],
|
||||
"embeddings": None,
|
||||
"ids": ["id"],
|
||||
"metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
"data": None,
|
||||
"uris": None,
|
||||
}
|
||||
|
||||
assert data == expected_value
|
||||
|
||||
data_without_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
|
||||
)
|
||||
expected_value_without_citations = ["document"]
|
||||
assert data_without_citations == expected_value_without_citations
|
||||
|
||||
app_with_settings.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
|
||||
# Start with a clean app
|
||||
app_with_settings.db.reset()
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document", "document2"],
|
||||
metadatas=[{"value": "somevalue"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
app_with_settings.db.add(
|
||||
embeddings=None,
|
||||
documents=["document", "document2"],
|
||||
metadatas=[{"value": "somevalue"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
app_with_settings.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_collections_are_persistent():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
@@ -312,60 +245,3 @@ def test_chroma_db_collection_reset():
|
||||
app2.db.reset()
|
||||
app3.db.reset()
|
||||
app4.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_query(app_with_settings):
|
||||
app_with_settings.db.reset()
|
||||
|
||||
assert app_with_settings.db.count() == 0
|
||||
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document"],
|
||||
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 1
|
||||
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 1, 0]],
|
||||
documents=["document2"],
|
||||
metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
|
||||
ids=["id2"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
|
||||
assert app_with_settings.db.count() == 2
|
||||
|
||||
data_without_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
|
||||
)
|
||||
expected_value_without_citations = ["document", "document2"]
|
||||
assert data_without_citations == expected_value_without_citations
|
||||
|
||||
data_with_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
|
||||
)
|
||||
expected_value_with_citations = [
|
||||
(
|
||||
"document",
|
||||
{
|
||||
"url": "url_1",
|
||||
"doc_id": "doc_id_1",
|
||||
"score": 0.0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"document2",
|
||||
{
|
||||
"url": "url_2",
|
||||
"doc_id": "doc_id_2",
|
||||
"score": 1.0,
|
||||
},
|
||||
),
|
||||
]
|
||||
assert data_with_citations == expected_value_with_citations
|
||||
|
||||
app_with_settings.db.reset()
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestEsDB(unittest.TestCase):
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
|
||||
self.db.add(embeddings, documents, metadatas, ids)
|
||||
|
||||
search_response = {
|
||||
"hits": {
|
||||
@@ -60,63 +60,17 @@ class TestEsDB(unittest.TestCase):
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = ["This is a document"]
|
||||
results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||
results_without_citations = self.db.query(query, n_results=2, where={})
|
||||
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||
|
||||
results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
|
||||
results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
|
||||
expected_results_with_citations = [
|
||||
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
|
||||
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
|
||||
]
|
||||
self.assertEqual(results_with_citations, expected_results_with_citations)
|
||||
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query_with_skip_embedding(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
self.assertEqual(self.db.client, mock_client.return_value)
|
||||
|
||||
# Create some dummy data.
|
||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
|
||||
|
||||
search_response = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||
"_score": 0.9,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"text": "This is another document.",
|
||||
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||
},
|
||||
"_score": 0.8,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# Configure the mock client to return the mocked response.
|
||||
mock_client.return_value.search.return_value = search_response
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = ["This is a document"]
|
||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that the results are correct.
|
||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
||||
|
||||
def test_init_without_url(self):
|
||||
# Make sure it's not loaded from env
|
||||
try:
|
||||
|
||||
@@ -54,7 +54,7 @@ class TestPinecone:
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["doc1", "doc2"]
|
||||
db.add(vectors, documents, metadatas, ids, True)
|
||||
db.add(vectors, documents, metadatas, ids)
|
||||
|
||||
expected_pinecone_upsert_args = [
|
||||
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
||||
@@ -81,7 +81,7 @@ class TestPinecone:
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
n_results = 1
|
||||
db.query(input_query, n_results, where={}, skip_embedding=False)
|
||||
db.query(input_query, n_results, where={})
|
||||
|
||||
# Assert that the Pinecone client was called to query the database
|
||||
pinecone_client_mock.query.assert_called_once_with(
|
||||
|
||||
@@ -12,6 +12,11 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.qdrant import QdrantDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
|
||||
"""A mock embedding function."""
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
class TestQdrantDB(unittest.TestCase):
|
||||
TEST_UUIDS = ["abc", "def", "ghi"]
|
||||
|
||||
@@ -25,6 +30,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -42,6 +48,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -61,6 +68,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -71,8 +79,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["123", "456"]
|
||||
skip_embedding = True
|
||||
db.add(embeddings, documents, metadatas, ids, skip_embedding)
|
||||
db.add(embeddings, documents, metadatas, ids)
|
||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
points=Batch(
|
||||
@@ -98,6 +105,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -105,7 +113,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
|
||||
qdrant_client_mock.return_value.search.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
@@ -119,7 +127,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
)
|
||||
]
|
||||
),
|
||||
query_vector=["This is a test document."],
|
||||
query_vector=[1, 2, 3],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
@@ -128,6 +136,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
@@ -142,6 +151,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
|
||||
@@ -8,6 +8,11 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.weaviate import WeaviateDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
|
||||
"""A mock embedding function."""
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
class TestWeaviateDb(unittest.TestCase):
|
||||
def test_incorrect_config_throws_error(self):
|
||||
"""Test the init method of the WeaviateDb class throws error for incorrect config"""
|
||||
@@ -25,6 +30,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -92,6 +98,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -111,6 +118,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -122,8 +130,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [None, None]
|
||||
ids = ["123", "456"]
|
||||
skip_embedding = True
|
||||
db.add(embeddings, documents, metadatas, ids, skip_embedding)
|
||||
db.add(embeddings, documents, metadatas, ids)
|
||||
|
||||
# Check if the document was added to the database.
|
||||
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
|
||||
@@ -155,6 +162,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -162,12 +170,10 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
|
||||
{"vector": ["This is a test document."]}
|
||||
)
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
|
||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||
def test_query_with_where(self, weaviate_mock):
|
||||
@@ -180,6 +186,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -187,15 +194,13 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
||||
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
||||
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
|
||||
)
|
||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
|
||||
{"vector": ["This is a test document."]}
|
||||
)
|
||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
|
||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||
def test_reset(self, weaviate_mock):
|
||||
@@ -206,6 +211,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -228,6 +234,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
|
||||
@@ -108,65 +108,7 @@ class TestZillizDBCollection:
|
||||
|
||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||
def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
|
||||
"""
|
||||
Test if the `ZillizVectorDB` instance is takes in the query with skip_embeddings.
|
||||
"""
|
||||
# Create an instance of ZillizVectorDB with mock config
|
||||
zilliz_db = ZillizVectorDB(config=mock_config)
|
||||
|
||||
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
|
||||
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
|
||||
|
||||
assert zilliz_db.client == mock_client()
|
||||
|
||||
# Mock the MilvusClient search method
|
||||
with patch.object(zilliz_db.client, "search") as mock_search:
|
||||
# Mock the search result
|
||||
mock_search.return_value = [
|
||||
[
|
||||
{
|
||||
"distance": 0.5,
|
||||
"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
# Call the query method with skip_embedding=True
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
output_fields=["*"],
|
||||
)
|
||||
|
||||
# Assert that the query result matches the expected result
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
output_fields=["*"],
|
||||
)
|
||||
|
||||
assert query_result_with_citations == [
|
||||
("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
|
||||
]
|
||||
|
||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||
def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
|
||||
"""
|
||||
Test if the `ZillizVectorDB` instance is takes in the query without skip_embeddings.
|
||||
"""
|
||||
def test_query(self, mock_connect, mock_client, mock_embedder, mock_config):
|
||||
# Create an instance of ZillizVectorDB with mock config
|
||||
zilliz_db = ZillizVectorDB(config=mock_config)
|
||||
|
||||
@@ -193,8 +135,7 @@ class TestZillizDBCollection:
|
||||
]
|
||||
]
|
||||
|
||||
# Call the query method with skip_embedding=False
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_with(
|
||||
@@ -208,7 +149,7 @@ class TestZillizDBCollection:
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
|
||||
input_query=["query_text"], n_results=1, where={}, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
|
||||
Reference in New Issue
Block a user