[Refactor] Converge Pipeline and App classes (#1021)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-29 16:52:41 +05:30
committed by GitHub
parent c0aafd38c9
commit a926bcc640
91 changed files with 646 additions and 875 deletions

View File

@@ -20,8 +20,9 @@ def chroma_db():
@pytest.fixture
def app_with_settings():
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
chroma_db = ChromaDB(config=chroma_config)
app_config = AppConfig(collect_metrics=False)
return App(config=app_config, db_config=chroma_config)
return App(config=app_config, db=chroma_db)
@pytest.fixture(scope="session", autouse=True)
@@ -65,7 +66,8 @@ def test_app_init_with_host_and_port(mock_client):
port = "1234"
config = AppConfig(collect_metrics=False)
db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, db_config=db_config)
db = ChromaDB(config=db_config)
_app = App(config=config, db=db)
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == host
@@ -74,7 +76,8 @@ def test_app_init_with_host_and_port(mock_client):
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port_none(mock_client):
_app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
_app = App(config=AppConfig(collect_metrics=False), db=db)
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host is None
@@ -82,7 +85,8 @@ def test_app_init_with_host_and_port_none(mock_client):
def test_chroma_db_duplicates_throw_warning(caplog):
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
@@ -91,7 +95,8 @@ def test_chroma_db_duplicates_throw_warning(caplog):
def test_chroma_db_duplicates_collections_no_warning(caplog):
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection_name("test_collection_2")
@@ -104,24 +109,28 @@ def test_chroma_db_duplicates_collections_no_warning(caplog):
def test_chroma_db_collection_init_with_default_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
assert app.db.collection.name == "embedchain_store"
def test_chroma_db_collection_init_with_custom_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name(name="test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_set_collection_name():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_changes_encapsulated():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 0
@@ -207,12 +216,14 @@ def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
def test_chroma_db_collection_collections_are_persistent():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
@@ -220,13 +231,15 @@ def test_chroma_db_collection_collections_are_persistent():
def test_chroma_db_collection_parallel_collections():
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
app1 = App(
AppConfig(collection_name="test_collection_1", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
config=AppConfig(collect_metrics=False),
db=db1,
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
app2 = App(
AppConfig(collection_name="test_collection_2", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
config=AppConfig(collect_metrics=False),
db=db2,
)
# cleanup if any previous tests failed or were interrupted
@@ -251,13 +264,11 @@ def test_chroma_db_collection_parallel_collections():
def test_chroma_db_collection_ids_share_collections():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("one_collection")
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
@@ -272,21 +283,17 @@ def test_chroma_db_collection_ids_share_collections():
def test_chroma_db_collection_reset():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("two_collection")
app3 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
app3.set_collection_name("three_collection")
app4 = App(
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
app4.set_collection_name("four_collection")
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])

View File

@@ -13,7 +13,7 @@ class TestEsDB(unittest.TestCase):
def test_setUp(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
self.vector_dim = 384
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
@@ -22,8 +22,8 @@ class TestEsDB(unittest.TestCase):
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collection_name=False, collect_metrics=False)
self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
@@ -74,7 +74,7 @@ class TestEsDB(unittest.TestCase):
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query_with_skip_embedding(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.

View File

@@ -29,7 +29,7 @@ class TestPinecone:
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Assert that the embedder was set
assert db.embedder == embedder
@@ -48,7 +48,7 @@ class TestPinecone:
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)
App(config=app_config, db=db, embedding_model=base_embedder)
# Add some documents to the database
documents = ["This is a document.", "This is another document."]
@@ -76,7 +76,7 @@ class TestPinecone:
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)
App(config=app_config, db=db, embedding_model=base_embedder)
# Query the database for documents that are similar to "document"
input_query = ["document"]
@@ -94,7 +94,7 @@ class TestPinecone:
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=BaseEmbedder())
App(config=app_config, db=db, embedding_model=BaseEmbedder())
# Reset the database
db.reset()

View File

@@ -29,7 +29,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
self.assertEqual(db.collection_name, "embedchain-store-1526")
self.assertEqual(db.client, qdrant_client_mock.return_value)
@@ -46,7 +46,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
resp = db.get(ids=[], where={})
self.assertEqual(resp, {"ids": []})
@@ -65,7 +65,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
documents = ["This is a test document.", "This is another test document."]
@@ -76,7 +76,7 @@ class TestQdrantDB(unittest.TestCase):
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
points=Batch(
ids=["abc", "def"],
ids=["def", "ghi"],
payloads=[
{
"identifier": "123",
@@ -102,7 +102,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
@@ -132,7 +132,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
db.count()
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
@@ -146,7 +146,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
db.reset()
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(

View File

@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
expected_class_obj = {
"classes": [
@@ -96,7 +96,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
expected_client = db._get_or_create_db()
self.assertEqual(expected_client, weaviate_client_mock)
@@ -115,7 +115,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
db.BATCH_SIZE = 1
embeddings = [[1, 2, 3], [4, 5, 6]]
@@ -159,7 +159,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
@@ -184,7 +184,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
@@ -210,7 +210,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Reset the database.
db.reset()
@@ -232,7 +232,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Reset the database.
db.count()