Add GPT4Vision Image loader (#1089)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -8,6 +8,11 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.weaviate import WeaviateDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
|
||||
"""A mock embedding function."""
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
class TestWeaviateDb(unittest.TestCase):
|
||||
def test_incorrect_config_throws_error(self):
|
||||
"""Test the init method of the WeaviateDb class throws error for incorrect config"""
|
||||
@@ -25,6 +30,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -92,6 +98,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -111,6 +118,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -122,8 +130,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [None, None]
|
||||
ids = ["123", "456"]
|
||||
skip_embedding = True
|
||||
db.add(embeddings, documents, metadatas, ids, skip_embedding)
|
||||
db.add(embeddings, documents, metadatas, ids)
|
||||
|
||||
# Check if the document was added to the database.
|
||||
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
|
||||
@@ -155,6 +162,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -162,12 +170,10 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
|
||||
{"vector": ["This is a test document."]}
|
||||
)
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
|
||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||
def test_query_with_where(self, weaviate_mock):
|
||||
@@ -180,6 +186,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -187,15 +194,13 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
||||
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
||||
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
|
||||
)
|
||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
|
||||
{"vector": ["This is a test document."]}
|
||||
)
|
||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
|
||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||
def test_reset(self, weaviate_mock):
|
||||
@@ -206,6 +211,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
@@ -228,6 +234,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
|
||||
Reference in New Issue
Block a user