[Bug fix] Fix embedding issue for opensearch and some other vector databases (#1163)

This commit is contained in:
Deshraj Yadav
2024-01-12 14:15:39 +05:30
committed by GitHub
parent c020e65a50
commit 862ff6cca6
13 changed files with 40 additions and 95 deletions

View File

@@ -28,14 +28,13 @@ class TestEsDB(unittest.TestCase):
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
# Create some dummy data.
embeddings = [[1, 2, 3], [4, 5, 6]]
# Create some dummy data
documents = ["This is a document.", "This is another document."]
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
ids = ["doc_1", "doc_2"]
# Add the data to the database.
self.db.add(embeddings, documents, metadatas, ids)
self.db.add(documents, metadatas, ids)
search_response = {
"hits": {

View File

@@ -43,8 +43,8 @@ class TestPinecone:
embedding_function = mock.Mock()
base_embedder = BaseEmbedder()
base_embedder.set_embedding_fn(embedding_function)
vectors = [[0, 0, 0], [1, 1, 1]]
embedding_function.return_value = vectors
embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
@@ -54,7 +54,7 @@ class TestPinecone:
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc1", "doc2"]
db.add(vectors, documents, metadatas, ids)
db.add(documents, metadatas, ids)
expected_pinecone_upsert_args = [
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},

View File

@@ -75,11 +75,10 @@ class TestQdrantDB(unittest.TestCase):
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
documents = ["This is a test document.", "This is another test document."]
metadatas = [{}, {}]
ids = ["123", "456"]
db.add(embeddings, documents, metadatas, ids)
db.add(documents, metadatas, ids)
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
points=Batch(
@@ -96,7 +95,7 @@ class TestQdrantDB(unittest.TestCase):
"metadata": {"text": "This is another test document."},
},
],
vectors=embeddings,
vectors=[[1, 2, 3], [4, 5, 6]],
),
)

View File

@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
weaviate_client_schema_mock.exists.return_value = False
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -40,7 +40,7 @@ class TestWeaviateDb(unittest.TestCase):
expected_class_obj = {
"classes": [
{
"class": "Embedchain_store_1526",
"class": "Embedchain_store_1536",
"vectorizer": "none",
"properties": [
{
@@ -53,12 +53,12 @@ class TestWeaviateDb(unittest.TestCase):
},
{
"name": "metadata",
"dataType": ["Embedchain_store_1526_metadata"],
"dataType": ["Embedchain_store_1536_metadata"],
},
],
},
{
"class": "Embedchain_store_1526_metadata",
"class": "Embedchain_store_1536_metadata",
"vectorizer": "none",
"properties": [
{
@@ -88,7 +88,7 @@ class TestWeaviateDb(unittest.TestCase):
# Assert that the Weaviate client was initialized
weaviate_mock.Client.assert_called_once()
self.assertEqual(db.index_name, "Embedchain_store_1526")
self.assertEqual(db.index_name, "Embedchain_store_1536")
weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
@patch("embedchain.vectordb.weaviate.weaviate")
@@ -97,7 +97,7 @@ class TestWeaviateDb(unittest.TestCase):
weaviate_client_mock = weaviate_mock.Client.return_value
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -117,7 +117,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -126,30 +126,21 @@ class TestWeaviateDb(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
db.BATCH_SIZE = 1
embeddings = [[1, 2, 3], [4, 5, 6]]
documents = ["This is a test document.", "This is another test document."]
metadatas = [None, None]
ids = ["123", "456"]
db.add(embeddings, documents, metadatas, ids)
documents = ["This is test document"]
metadatas = [None]
ids = ["id_1"]
db.add(documents, metadatas, ids)
# Check if the document was added to the database.
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0]
)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1]
data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"identifier": ids[0], "text": documents[0]},
class_name="Embedchain_store_1526",
vector=embeddings[0],
)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"identifier": ids[1], "text": documents[1]},
class_name="Embedchain_store_1526",
vector=embeddings[1],
data_object={"text": documents[0]},
class_name="Embedchain_store_1536_metadata",
vector=[1, 2, 3],
)
@patch("embedchain.vectordb.weaviate.weaviate")
@@ -161,7 +152,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -172,7 +163,7 @@ class TestWeaviateDb(unittest.TestCase):
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
@@ -185,7 +176,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -196,9 +187,9 @@ class TestWeaviateDb(unittest.TestCase):
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
weaviate_client_query_get_mock.with_where.assert_called_once_with(
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
)
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@@ -210,7 +201,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -222,7 +213,7 @@ class TestWeaviateDb(unittest.TestCase):
db.reset()
weaviate_client_batch_mock.delete_objects.assert_called_once_with(
"Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
"Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
)
@patch("embedchain.vectordb.weaviate.weaviate")
@@ -233,7 +224,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
@@ -244,4 +235,4 @@ class TestWeaviateDb(unittest.TestCase):
# Reset the database.
db.count()
weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")
weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")