[Refactor] Converge Pipeline and App classes (#1021)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
@@ -15,7 +16,7 @@ os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
@pytest.fixture
|
||||
def app_instance():
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
return App(config)
|
||||
return App(config=config)
|
||||
|
||||
|
||||
def test_whole_app(app_instance, mocker):
|
||||
@@ -44,9 +45,9 @@ def test_add_after_reset(app_instance, mocker):
|
||||
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
chroma_config = {"allow_reset": True}
|
||||
|
||||
app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
|
||||
chroma_config = ChromaDbConfig(allow_reset=True)
|
||||
db = ChromaDB(config=chroma_config)
|
||||
app_instance = App(config=config, db=db)
|
||||
|
||||
# mock delete chat history
|
||||
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
|
||||
|
||||
@@ -114,5 +114,7 @@ class TestApp(unittest.TestCase):
|
||||
self.assertEqual(answer, "Test answer")
|
||||
_args, kwargs = mock_database_query.call_args
|
||||
self.assertEqual(kwargs.get("input_query"), "Test query")
|
||||
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
|
||||
where = kwargs.get("where")
|
||||
assert "app_id" in where
|
||||
assert "attribute" in where
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -37,25 +38,14 @@ def test_query_config_app_passing(mock_get_answer):
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config, llm_config=chat_config)
|
||||
llm = OpenAILlm(config=chat_config)
|
||||
app = App(config=config, llm=llm)
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_app_passing(mock_get_answer):
|
||||
mock_get_answer.return_value = MagicMock()
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig()
|
||||
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(app):
|
||||
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
|
||||
@@ -83,5 +73,7 @@ def test_query_with_where_in_query_config(app):
|
||||
assert answer == "Test answer"
|
||||
_, kwargs = mock_database_query.call_args
|
||||
assert kwargs.get("input_query") == "Test query"
|
||||
assert kwargs.get("where") == {"attribute": "value"}
|
||||
where = kwargs.get("where")
|
||||
assert "app_id" in where
|
||||
assert "attribute" in where
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -4,11 +4,10 @@ import pytest
|
||||
import yaml
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
|
||||
BaseLlmConfig, ChromaDbConfig)
|
||||
from embedchain.config import ChromaDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -21,13 +20,14 @@ def app():
|
||||
def test_app(app):
|
||||
assert isinstance(app.llm, BaseLlm)
|
||||
assert isinstance(app.db, BaseVectorDB)
|
||||
assert isinstance(app.embedder, BaseEmbedder)
|
||||
assert isinstance(app.embedding_model, BaseEmbedder)
|
||||
|
||||
|
||||
class TestConfigForAppComponents:
|
||||
def test_constructor_config(self):
|
||||
collection_name = "my-test-collection"
|
||||
app = App(db_config=ChromaDbConfig(collection_name=collection_name))
|
||||
db = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
|
||||
app = App(db=db)
|
||||
assert app.db.config.collection_name == collection_name
|
||||
|
||||
def test_component_config(self):
|
||||
@@ -36,50 +36,6 @@ class TestConfigForAppComponents:
|
||||
app = App(db=database)
|
||||
assert app.db.config.collection_name == collection_name
|
||||
|
||||
def test_different_configs_are_proper_instances(self):
|
||||
app_config = AppConfig()
|
||||
wrong_config = AddConfig()
|
||||
with pytest.raises(TypeError):
|
||||
App(config=wrong_config)
|
||||
|
||||
assert isinstance(app_config, AppConfig)
|
||||
|
||||
llm_config = BaseLlmConfig()
|
||||
wrong_llm_config = "wrong_llm_config"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
App(llm_config=wrong_llm_config)
|
||||
|
||||
assert isinstance(llm_config, BaseLlmConfig)
|
||||
|
||||
db_config = BaseVectorDbConfig()
|
||||
wrong_db_config = "wrong_db_config"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
App(db_config=wrong_db_config)
|
||||
|
||||
assert isinstance(db_config, BaseVectorDbConfig)
|
||||
|
||||
embedder_config = BaseEmbedderConfig()
|
||||
wrong_embedder_config = "wrong_embedder_config"
|
||||
with pytest.raises(TypeError):
|
||||
App(embedder_config=wrong_embedder_config)
|
||||
|
||||
assert isinstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
def test_components_raises_type_error_if_not_proper_instances(self):
|
||||
wrong_llm = "wrong_llm"
|
||||
with pytest.raises(TypeError):
|
||||
App(llm=wrong_llm)
|
||||
|
||||
wrong_db = "wrong_db"
|
||||
with pytest.raises(TypeError):
|
||||
App(db=wrong_db)
|
||||
|
||||
wrong_embedder = "wrong_embedder"
|
||||
with pytest.raises(TypeError):
|
||||
App(embedder=wrong_embedder)
|
||||
|
||||
|
||||
class TestAppFromConfig:
|
||||
def load_config_data(self, yaml_path):
|
||||
@@ -92,14 +48,13 @@ class TestAppFromConfig:
|
||||
yaml_path = "configs/chroma.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
app = App.from_config(config_path=yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
# Even though not present in the config, the default value is used
|
||||
assert app.config.collect_metrics is True
|
||||
|
||||
@@ -118,8 +73,8 @@ class TestAppFromConfig:
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.model == embedder_config["model"]
|
||||
assert app.embedder.config.deployment_name == embedder_config.get("deployment_name")
|
||||
assert app.embedding_model.config.model == embedder_config["model"]
|
||||
assert app.embedding_model.config.deployment_name == embedder_config.get("deployment_name")
|
||||
|
||||
def test_from_opensource_config(self, mocker):
|
||||
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
@@ -134,7 +89,6 @@ class TestAppFromConfig:
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
|
||||
|
||||
# Validate the LLM config values
|
||||
@@ -153,4 +107,4 @@ class TestAppFromConfig:
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
|
||||
assert app.embedding_model.config.deployment_name == embedder_config["deployment_name"]
|
||||
@@ -20,8 +20,9 @@ def chroma_db():
|
||||
@pytest.fixture
|
||||
def app_with_settings():
|
||||
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
chroma_db = ChromaDB(config=chroma_config)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
return App(config=app_config, db_config=chroma_config)
|
||||
return App(config=app_config, db=chroma_db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -65,7 +66,8 @@ def test_app_init_with_host_and_port(mock_client):
|
||||
port = "1234"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
db_config = ChromaDbConfig(host=host, port=port)
|
||||
_app = App(config, db_config=db_config)
|
||||
db = ChromaDB(config=db_config)
|
||||
_app = App(config=config, db=db)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host == host
|
||||
@@ -74,7 +76,8 @@ def test_app_init_with_host_and_port(mock_client):
|
||||
|
||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
def test_app_init_with_host_and_port_none(mock_client):
|
||||
_app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
_app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host is None
|
||||
@@ -82,7 +85,8 @@ def test_app_init_with_host_and_port_none(mock_client):
|
||||
|
||||
|
||||
def test_chroma_db_duplicates_throw_warning(caplog):
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" in caplog.text
|
||||
@@ -91,7 +95,8 @@ def test_chroma_db_duplicates_throw_warning(caplog):
|
||||
|
||||
|
||||
def test_chroma_db_duplicates_collections_no_warning(caplog):
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.set_collection_name("test_collection_2")
|
||||
@@ -104,24 +109,28 @@ def test_chroma_db_duplicates_collections_no_warning(caplog):
|
||||
|
||||
|
||||
def test_chroma_db_collection_init_with_default_collection():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
assert app.db.collection.name == "embedchain_store"
|
||||
|
||||
|
||||
def test_chroma_db_collection_init_with_custom_collection():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name(name="test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_chroma_db_collection_set_collection_name():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_chroma_db_collection_changes_encapsulated():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 0
|
||||
|
||||
@@ -207,12 +216,14 @@ def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
|
||||
|
||||
|
||||
def test_chroma_db_collection_collections_are_persistent():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
del app
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 1
|
||||
|
||||
@@ -220,13 +231,15 @@ def test_chroma_db_collection_collections_are_persistent():
|
||||
|
||||
|
||||
def test_chroma_db_collection_parallel_collections():
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
|
||||
app1 = App(
|
||||
AppConfig(collection_name="test_collection_1", collect_metrics=False),
|
||||
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db1,
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
|
||||
app2 = App(
|
||||
AppConfig(collection_name="test_collection_2", collect_metrics=False),
|
||||
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db2,
|
||||
)
|
||||
|
||||
# cleanup if any previous tests failed or were interrupted
|
||||
@@ -251,13 +264,11 @@ def test_chroma_db_collection_parallel_collections():
|
||||
|
||||
|
||||
def test_chroma_db_collection_ids_share_collections():
|
||||
app1 = App(
|
||||
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
app2 = App(
|
||||
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("one_collection")
|
||||
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
@@ -272,21 +283,17 @@ def test_chroma_db_collection_ids_share_collections():
|
||||
|
||||
|
||||
def test_chroma_db_collection_reset():
|
||||
app1 = App(
|
||||
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
app2 = App(
|
||||
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("two_collection")
|
||||
app3 = App(
|
||||
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
|
||||
app3.set_collection_name("three_collection")
|
||||
app4 = App(
|
||||
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
|
||||
app4.set_collection_name("four_collection")
|
||||
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
|
||||
@@ -13,7 +13,7 @@ class TestEsDB(unittest.TestCase):
|
||||
def test_setUp(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
self.vector_dim = 384
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
@@ -22,8 +22,8 @@ class TestEsDB(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
self.assertEqual(self.db.client, mock_client.return_value)
|
||||
@@ -74,7 +74,7 @@ class TestEsDB(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query_with_skip_embedding(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestPinecone:
|
||||
# Create a PineconeDB instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Assert that the embedder was set
|
||||
assert db.embedder == embedder
|
||||
@@ -48,7 +48,7 @@ class TestPinecone:
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
|
||||
# Add some documents to the database
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
@@ -76,7 +76,7 @@ class TestPinecone:
|
||||
# Create a PineconeDB instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
@@ -94,7 +94,7 @@ class TestPinecone:
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=BaseEmbedder())
|
||||
App(config=app_config, db=db, embedding_model=BaseEmbedder())
|
||||
|
||||
# Reset the database
|
||||
db.reset()
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
self.assertEqual(db.collection_name, "embedchain-store-1526")
|
||||
self.assertEqual(db.client, qdrant_client_mock.return_value)
|
||||
@@ -46,7 +46,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
resp = db.get(ids=[], where={})
|
||||
self.assertEqual(resp, {"ids": []})
|
||||
@@ -65,7 +65,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
@@ -76,7 +76,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
points=Batch(
|
||||
ids=["abc", "def"],
|
||||
ids=["def", "ghi"],
|
||||
payloads=[
|
||||
{
|
||||
"identifier": "123",
|
||||
@@ -102,7 +102,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
@@ -132,7 +132,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
db.count()
|
||||
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
|
||||
@@ -146,7 +146,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
db.reset()
|
||||
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
expected_class_obj = {
|
||||
"classes": [
|
||||
@@ -96,7 +96,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
expected_client = db._get_or_create_db()
|
||||
self.assertEqual(expected_client, weaviate_client_mock)
|
||||
@@ -115,7 +115,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
db.BATCH_SIZE = 1
|
||||
|
||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||
@@ -159,7 +159,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
|
||||
@@ -184,7 +184,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
@@ -210,7 +210,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Reset the database.
|
||||
db.reset()
|
||||
@@ -232,7 +232,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Reset the database.
|
||||
db.count()
|
||||
|
||||
Reference in New Issue
Block a user