diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 53c3d57f..2fcbbb07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,9 +23,9 @@ jobs: - name: Install dependencies run: poetry install --all-extras - name: Lint with ruff - run: make ci_lint + run: make lint - name: Test with pytest - run: make ci_test + run: make test - name: Generate coverage report run: make coverage - name: Upload coverage reports to Codecov diff --git a/Makefile b/Makefile index e2bfad01..217c6bb1 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ install_opensearch: install_milvus: poetry install --extras milvus - + shell: poetry shell @@ -31,19 +31,13 @@ format: $(PYTHON) -m black . $(PYTHON) -m isort . -lint: - $(PYTHON) -m ruff . - clean: rm -rf dist build *.egg-info -test: - $(PYTHON) -m pytest - -ci_lint: +lint: poetry run ruff . -ci_test: +test: poetry run pytest coverage: diff --git a/embedchain/__init__.py b/embedchain/__init__.py index b26a8f24..a0eb7de4 100644 --- a/embedchain/__init__.py +++ b/embedchain/__init__.py @@ -7,5 +7,5 @@ from embedchain.apps.custom_app import CustomApp # noqa: F401 from embedchain.apps.Llama2App import Llama2App # noqa: F401 from embedchain.apps.open_source_app import OpenSourceApp # noqa: F401 from embedchain.apps.person_app import (PersonApp, # noqa: F401 - PersonOpenSourceApp) + PersonOpenSourceApp) from embedchain.vectordb.chroma import ChromaDB # noqa: F401 diff --git a/embedchain/apps/app.py b/embedchain/apps/app.py index 9148fec8..18f03c8e 100644 --- a/embedchain/apps/app.py +++ b/embedchain/apps/app.py @@ -1,12 +1,13 @@ -import logging from typing import Optional -from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig, - ChromaDbConfig) +import yaml + +from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.embedchain import EmbedChain from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.openai import OpenAIEmbedder +from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.helper.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm from embedchain.llm.openai import OpenAILlm @@ -35,7 +36,6 @@ class App(EmbedChain): db_config: Optional[BaseVectorDbConfig] = None, embedder: BaseEmbedder = None, embedder_config: Optional[BaseEmbedderConfig] = None, - chromadb_config: Optional[ChromaDbConfig] = None, system_prompt: Optional[str] = None, ): """ @@ -60,20 +60,10 @@ class App(EmbedChain): :param embedder_config: Allows you to configure the Embedder. example: `from embedchain.config import BaseEmbedderConfig`, defaults to None :type embedder_config: Optional[BaseEmbedderConfig], optional - :param chromadb_config: Deprecated alias of `db_config`, defaults to None - :type chromadb_config: Optional[ChromaDbConfig], optional :param system_prompt: System prompt that will be provided to the LLM as such, defaults to None :type system_prompt: Optional[str], optional :raises TypeError: LLM, database or embedder or their config is not a valid class instance. """ - # Overwrite deprecated arguments - if chromadb_config: - logging.warning( - "DEPRECATION WARNING: Please use `db_config` argument instead of `chromadb_config`." - "`chromadb_config` will be removed in a future release." - ) - db_config = chromadb_config - # Type check configs if config and not isinstance(config, AppConfig): raise TypeError( @@ -123,3 +113,33 @@ class App(EmbedChain): "Please make sure the type is right and that you are passing an instance." ) super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt) + + @classmethod + def from_config(cls, yaml_path: str): + """ + Instantiate an App object from a YAML configuration file. + + :param yaml_path: Path to the YAML configuration file. + :type yaml_path: str + :return: An instance of the App class. + :rtype: App + """ + with open(yaml_path, "r") as file: + config_data = yaml.safe_load(file) + + app_config_data = config_data.get("app", {}) + llm_config_data = config_data.get("llm", {}) + db_config_data = config_data.get("vectordb", {}) + embedder_config_data = config_data.get("embedder", {}) + + app_config = AppConfig(**app_config_data.get("config", {})) + + llm_provider = llm_config_data.get("provider", "openai") + llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {})) + + db_provider = db_config_data.get("provider", "chroma") + db = VectorDBFactory.create(db_provider, db_config_data.get("config", {})) + + embedder_provider = embedder_config_data.get("provider", "openai") + embedder = EmbedderFactory.create(embedder_provider, embedder_config_data.get("config", {})) + return cls(config=app_config, llm=llm, db=db, embedder=embedder) diff --git a/embedchain/bots/base.py b/embedchain/bots/base.py index 5be27121..03279857 100644 --- a/embedchain/bots/base.py +++ b/embedchain/bots/base.py @@ -3,7 +3,8 @@ from typing import Any from embedchain import App from embedchain.config import AddConfig, AppConfig, LlmConfig from embedchain.embedder.openai import OpenAIEmbedder -from embedchain.helper.json_serializable import JSONSerializable, register_deserializable +from embedchain.helper.json_serializable import (JSONSerializable, + register_deserializable) from embedchain.llm.openai import OpenAILlm from embedchain.vectordb.chroma import ChromaDB diff --git a/embedchain/config/vectordb/base.py b/embedchain/config/vectordb/base.py index d0bdbd8a..3252880a 100644 --- a/embedchain/config/vectordb/base.py +++ b/embedchain/config/vectordb/base.py @@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig): dir: str = "db", host: Optional[str] = None, port: Optional[str] = None, + **kwargs, ): """ Initializes a configuration class instance for the vector database. @@ -22,8 +23,14 @@ class BaseVectorDbConfig(BaseConfig): :type host: Optional[str], optional :param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None :type port: Optional[str], optional + :param kwargs: Additional keyword arguments + :type kwargs: dict """ self.collection_name = collection_name or "embedchain_store" self.dir = dir self.host = host self.port = port + # Assign additional keyword arguments + if kwargs: + for key, value in kwargs.items(): + setattr(self, key, value) diff --git a/embedchain/factory.py b/embedchain/factory.py new file mode 100644 index 00000000..900695b8 --- /dev/null +++ b/embedchain/factory.py @@ -0,0 +1,88 @@ +import importlib + + +def load_class(class_type): + module_path, class_name = class_type.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +class LlmFactory: + provider_to_class = { + "anthropic": "embedchain.llm.anthropic.AnthropicLlm", + "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm", + "cohere": "embedchain.llm.cohere.CohereLlm", + "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm", + "hugging_face_llm": "embedchain.llm.hugging_face_llm.HuggingFaceLlm", + "jina": "embedchain.llm.jina.JinaLlm", + "llama2": "embedchain.llm.llama2.Llama2Llm", + "openai": "embedchain.llm.openai.OpenAILlm", + "vertexai": "embedchain.llm.vertex_ai.VertexAILlm", + } + provider_to_config_class = { + "embedchain": "embedchain.config.llm.base_llm_config.BaseLlmConfig", + "openai": "embedchain.config.llm.base_llm_config.BaseLlmConfig", + "anthropic": "embedchain.config.llm.base_llm_config.BaseLlmConfig", + } + + @classmethod + def create(cls, provider_name, config_data): + class_type = cls.provider_to_class.get(provider_name) + # Default to embedchain base config if the provider is not in the config map + config_name = "embedchain" if provider_name not in cls.provider_to_config_class else provider_name + config_class_type = cls.provider_to_config_class.get(config_name) + if class_type: + llm_class = load_class(class_type) + llm_config_class = load_class(config_class_type) + return llm_class(config=llm_config_class(**config_data)) + else: + raise ValueError(f"Unsupported Llm provider: {provider_name}") + + +class EmbedderFactory: + provider_to_class = { + "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder", + "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder", + "vertexai": "embedchain.embedder.vertexai.VertexAiEmbedder", + "openai": "embedchain.embedder.openai.OpenAIEmbedder", + } + provider_to_config_class = { + "openai": "embedchain.config.embedder.base.BaseEmbedderConfig", + } + + @classmethod + def create(cls, provider_name, config_data): + class_type = cls.provider_to_class.get(provider_name) + # Default to openai config if the provider is not in the config map + config_name = "openai" if provider_name not in cls.provider_to_config_class else provider_name + config_class_type = cls.provider_to_config_class.get(config_name) + if class_type: + embedder_class = load_class(class_type) + embedder_config_class = load_class(config_class_type) + return embedder_class(config=embedder_config_class(**config_data)) + else: + raise ValueError(f"Unsupported Embedder provider: {provider_name}") + + +class VectorDBFactory: + provider_to_class = { + "chroma": "embedchain.vectordb.chroma.ChromaDB", + "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", + "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", + } + provider_to_config_class = { + "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", + "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", + "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", + } + + @classmethod + def create(cls, provider_name, config_data): + class_type = cls.provider_to_class.get(provider_name) + config_class_type = cls.provider_to_config_class.get(provider_name) + if class_type: + embedder_class = load_class(class_type) + embedder_config_class = load_class(config_class_type) + return embedder_class(config=embedder_config_class(**config_data)) + else: + raise ValueError(f"Unsupported Embedder provider: {provider_name}") diff --git a/embedchain/llm/antrophic.py b/embedchain/llm/anthropic.py similarity index 90% rename from embedchain/llm/antrophic.py rename to embedchain/llm/anthropic.py index 08f653df..027da1ce 100644 --- a/embedchain/llm/antrophic.py +++ b/embedchain/llm/anthropic.py @@ -7,12 +7,12 @@ from embedchain.llm.base import BaseLlm @register_deserializable -class AntrophicLlm(BaseLlm): +class AnthropicLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): super().__init__(config=config) def get_llm_model_answer(self, prompt): - return AntrophicLlm._get_answer(prompt=prompt, config=self.config) + return AnthropicLlm._get_answer(prompt=prompt, config=self.config) @staticmethod def _get_answer(prompt: str, config: BaseLlmConfig) -> str: diff --git a/embedchain/llm/jina.py b/embedchain/llm/jina.py index 2c906c7e..2af5a798 100644 --- a/embedchain/llm/jina.py +++ b/embedchain/llm/jina.py @@ -34,7 +34,8 @@ class JinaLlm(BaseLlm): if config.top_p: kwargs["model_kwargs"]["top_p"] = config.top_p if config.stream: - from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + from langchain.callbacks.streaming_stdout import \ + StreamingStdOutCallbackHandler chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) else: diff --git a/embedchain/models/clip_processor.py b/embedchain/models/clip_processor.py index e349f4ae..46a89c16 100644 --- a/embedchain/models/clip_processor.py +++ b/embedchain/models/clip_processor.py @@ -2,9 +2,7 @@ try: from PIL import Image, UnidentifiedImageError from sentence_transformers import SentenceTransformer except ImportError: - raise ImportError( - "Images requires extra dependencies. Install with `pip install 'embedchain[images]'" - ) from None + raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'") from None MODEL_NAME = "clip-ViT-B-32" diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 1dfadd9b..be4e8604 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -37,7 +37,7 @@ class ChromaDB(BaseVectorDB): self.config = ChromaDbConfig() self.settings = Settings() - self.settings.allow_reset = self.config.allow_reset + self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False if self.config.chroma_settings: for key, value in self.config.chroma_settings.items(): if hasattr(self.settings, key): @@ -72,6 +72,17 @@ class ChromaDB(BaseVectorDB): """Called during initialization""" return self.client + def _generate_where_clause(self, where: Dict[str, any]) -> str: + # If only one filter is supplied, return it as is + # (no need to wrap in $and based on chroma docs) + if len(where.keys()) == 1: + return where + where_filters = [] + for k, v in where.items(): + if isinstance(v, str): + where_filters.append({k: v}) + return {"$and": where_filters} + def _get_or_create_collection(self, name: str) -> Collection: """ Get or create a named collection. @@ -107,13 +118,14 @@ class ChromaDB(BaseVectorDB): if ids: args["ids"] = ids if where: - args["where"] = where + args["where"] = self._generate_where_clause(where) if limit: args["limit"] = limit return self.collection.get(**args) def get_advanced(self, where): - return self.collection.get(where=where, limit=1) + where_clause = self._generate_where_clause(where) + return self.collection.get(where=where_clause, limit=1) def add( self, diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index edfa69f8..3857300b 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -110,8 +110,13 @@ class OpenSearchDB(BaseVectorDB): return result def add( - self, embeddings: List[List[str]], documents: List[str], metadatas: List[object], ids: List[str], - skip_embedding: bool): + self, + embeddings: List[List[str]], + documents: List[str], + metadatas: List[object], + ids: List[str], + skip_embedding: bool, + ): """add data in vector database :param embeddings: list of embeddings to add @@ -162,7 +167,8 @@ class OpenSearchDB(BaseVectorDB): embedding_function=embeddings, opensearch_url=f"{self.config.opensearch_url}", http_auth=self.config.http_auth, - use_ssl=True, + use_ssl=hasattr(self.config, "use_ssl") and self.config.use_ssl, + verify_certs=hasattr(self.config, "verify_certs") and self.config.verify_certs, ) pre_filter = {"match_all": {}} # default diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index dcd05cef..eb99ff2a 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -5,8 +5,8 @@ from embedchain.helper.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB try: - from pymilvus import MilvusClient - from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility + from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema, + MilvusClient, connections, utility) except ImportError: raise ImportError( "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`" diff --git a/embedchain/yaml/chroma.yaml b/embedchain/yaml/chroma.yaml new file mode 100644 index 00000000..7b340b9e --- /dev/null +++ b/embedchain/yaml/chroma.yaml @@ -0,0 +1,26 @@ +app: + config: + id: 'my-app' + collection_name: 'my-app' + +llm: + provider: openai + model: 'gpt-3.5-turbo' + config: + temperature: 0.5 + max_tokens: 1000 + top_p: 1 + stream: false + +vectordb: + provider: chroma + config: + collection_name: 'my-app' + dir: db + allow_reset: true + +embedder: + provider: openai + config: + model: 'text-embedding-ada-002' + deployment_name: null diff --git a/embedchain/yaml/opensearch.yaml b/embedchain/yaml/opensearch.yaml new file mode 100644 index 00000000..07cf1c2d --- /dev/null +++ b/embedchain/yaml/opensearch.yaml @@ -0,0 +1,33 @@ +app: + config: + id: 'my-app' + log_level: 'WARN' + collect_metrics: true + collection_name: 'my-app' + +llm: + provider: openai + model: 'gpt-3.5-turbo' + config: + temperature: 0.5 + max_tokens: 1000 + top_p: 1 + stream: false + +vectordb: + provider: opensearch + config: + opensearch_url: 'https://localhost:9200' + http_auth: + - admin + - admin + vector_dimension: 1536 + collection_name: 'my-app' + use_ssl: false + verify_certs: false + +embedder: + provider: openai + config: + model: 'text-embedding-ada-002' + deployment_name: null diff --git a/embedchain/yaml/opensource.yaml b/embedchain/yaml/opensource.yaml new file mode 100644 index 00000000..5e8e9c60 --- /dev/null +++ b/embedchain/yaml/opensource.yaml @@ -0,0 +1,27 @@ +app: + config: + id: 'open-source-app' + collection_name: 'open-source-app' + collect_metrics: false + +llm: + provider: gpt4all + model: 'orca-mini-3b.ggmlv3.q4_0.bin' + config: + temperature: 0.5 + max_tokens: 1000 + top_p: 1 + stream: false + +vectordb: + provider: chroma + config: + collection_name: 'open-source-app' + dir: db + allow_reset: true + +embedder: + provider: gpt4all + config: + model: 'all-MiniLM-L6-v2' + deployment_name: null diff --git a/tests/apps/test_apps.py b/tests/apps/test_apps.py index 0b2ad263..f046b560 100644 --- a/tests/apps/test_apps.py +++ b/tests/apps/test_apps.py @@ -1,8 +1,11 @@ import os import unittest +import yaml + from embedchain import App, CustomApp, Llama2App, OpenSourceApp -from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig +from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig, + BaseLlmConfig, ChromaDbConfig) from embedchain.embedder.base import BaseEmbedder from embedchain.llm.base import BaseLlm from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig @@ -100,3 +103,74 @@ class TestConfigForAppComponents(unittest.TestCase): App(embedder_config=wrong_embedder_config) self.assertIsInstance(embedder_config, BaseEmbedderConfig) + + +class TestAppFromConfig: + def load_config_data(self, yaml_path): + with open(yaml_path, "r") as file: + return yaml.safe_load(file) + + def test_from_chroma_config(self): + yaml_path = "embedchain/yaml/chroma.yaml" + config_data = self.load_config_data(yaml_path) + + app = App.from_config(yaml_path) + + # Check if the App instance and its components were created correctly + assert isinstance(app, App) + + # Validate the AppConfig values + assert app.config.id == config_data["app"]["config"]["id"] + assert app.config.collection_name == config_data["app"]["config"]["collection_name"] + # Even though not present in the config, the default value is used + assert app.config.collect_metrics is True + + # Validate the LLM config values + llm_config = config_data["llm"]["config"] + assert app.llm.config.temperature == llm_config["temperature"] + assert app.llm.config.max_tokens == llm_config["max_tokens"] + assert app.llm.config.top_p == llm_config["top_p"] + assert app.llm.config.stream == llm_config["stream"] + + # Validate the VectorDB config values + db_config = config_data["vectordb"]["config"] + assert app.db.config.collection_name == db_config["collection_name"] + assert app.db.config.dir == db_config["dir"] + assert app.db.config.allow_reset == db_config["allow_reset"] + + # Validate the Embedder config values + embedder_config = config_data["embedder"]["config"] + assert app.embedder.config.model == embedder_config["model"] + assert app.embedder.config.deployment_name == embedder_config["deployment_name"] + + def test_from_opensource_config(self): + yaml_path = "embedchain/yaml/opensource.yaml" + config_data = self.load_config_data(yaml_path) + + app = App.from_config(yaml_path) + + # Check if the App instance and its components were created correctly + assert isinstance(app, App) + + # Validate the AppConfig values + assert app.config.id == config_data["app"]["config"]["id"] + assert app.config.collection_name == config_data["app"]["config"]["collection_name"] + assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"] + + # Validate the LLM config values + llm_config = config_data["llm"]["config"] + assert app.llm.config.temperature == llm_config["temperature"] + assert app.llm.config.max_tokens == llm_config["max_tokens"] + assert app.llm.config.top_p == llm_config["top_p"] + assert app.llm.config.stream == llm_config["stream"] + + # Validate the VectorDB config values + db_config = config_data["vectordb"]["config"] + assert app.db.config.collection_name == db_config["collection_name"] + assert app.db.config.dir == db_config["dir"] + assert app.db.config.allow_reset == db_config["allow_reset"] + + # Validate the Embedder config values + embedder_config = config_data["embedder"]["config"] + assert app.embedder.config.model == embedder_config["model"] + assert app.embedder.config.deployment_name == embedder_config["deployment_name"] diff --git a/tests/bots/test_base.py b/tests/bots/test_base.py index 050de6e6..d15f06cf 100644 --- a/tests/bots/test_base.py +++ b/tests/bots/test_base.py @@ -1,9 +1,11 @@ import os -import pytest -from embedchain.config import AddConfig, BaseLlmConfig -from embedchain.bots.base import BaseBot from unittest.mock import patch +import pytest + +from embedchain.bots.base import BaseBot +from embedchain.config import AddConfig, BaseLlmConfig + @pytest.fixture def base_bot(): diff --git a/tests/chunkers/test_base_chunker.py b/tests/chunkers/test_base_chunker.py index 2a9deffb..6f89cd03 100644 --- a/tests/chunkers/test_base_chunker.py +++ b/tests/chunkers/test_base_chunker.py @@ -1,6 +1,8 @@ import hashlib -import pytest from unittest.mock import MagicMock + +import pytest + from embedchain.chunkers.base_chunker import BaseChunker from embedchain.models.data_type import DataType diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py index 1bd97e31..3c5ffc7e 100644 --- a/tests/embedchain/test_embedchain.py +++ b/tests/embedchain/test_embedchain.py @@ -41,7 +41,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): Test if the `App` instance is correctly reconstructed after a reset. """ config = AppConfig(log_level="DEBUG", collect_metrics=False) - app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True})) + chroma_config = {"allow_reset": True} + app = App(config=config, db_config=ChromaDbConfig(**chroma_config)) app.reset() # Make sure the client is still healthy diff --git a/tests/embedder/test_embedder.py b/tests/embedder/test_embedder.py index e83f9770..bf03cab4 100644 --- a/tests/embedder/test_embedder.py +++ b/tests/embedder/test_embedder.py @@ -1,9 +1,11 @@ -import pytest from unittest.mock import MagicMock -from embedchain.embedder.base import BaseEmbedder -from embedchain.config.embedder.base import BaseEmbedderConfig + +import pytest from chromadb.api.types import Documents, Embeddings +from embedchain.config.embedder.base import BaseEmbedderConfig +from embedchain.embedder.base import BaseEmbedder + @pytest.fixture def base_embedder(): diff --git a/tests/llm/test_antrophic.py b/tests/llm/test_antrophic.py index d2489a05..2bb74d35 100644 --- a/tests/llm/test_antrophic.py +++ b/tests/llm/test_antrophic.py @@ -1,62 +1,63 @@ -import pytest from unittest.mock import MagicMock, patch -from embedchain.llm.antrophic import AntrophicLlm -from embedchain.config import BaseLlmConfig +import pytest from langchain.schema import HumanMessage, SystemMessage +from embedchain.config import BaseLlmConfig +from embedchain.llm.anthropic import AnthropicLlm + @pytest.fixture -def antrophic_llm(): +def anthropic_llm(): config = BaseLlmConfig(temperature=0.5, model="gpt2") - return AntrophicLlm(config) + return AnthropicLlm(config) -def test_get_llm_model_answer(antrophic_llm): - with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method: +def test_get_llm_model_answer(anthropic_llm): + with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method: prompt = "Test Prompt" - response = antrophic_llm.get_llm_model_answer(prompt) + response = anthropic_llm.get_llm_model_answer(prompt) assert response == "Test Response" - mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config) + mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config) -def test_get_answer(antrophic_llm): +def test_get_answer(anthropic_llm): with patch("langchain.chat_models.ChatAnthropic") as mock_chat: mock_chat_instance = mock_chat.return_value mock_chat_instance.return_value = MagicMock(content="Test Response") prompt = "Test Prompt" - response = antrophic_llm._get_answer(prompt, antrophic_llm.config) + response = anthropic_llm._get_answer(prompt, anthropic_llm.config) assert response == "Test Response" mock_chat.assert_called_once_with( - temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model + temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model ) mock_chat_instance.assert_called_once_with( - antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt) + anthropic_llm._get_messages(prompt, system_prompt=anthropic_llm.config.system_prompt) ) -def test_get_messages(antrophic_llm): +def test_get_messages(anthropic_llm): prompt = "Test Prompt" system_prompt = "Test System Prompt" - messages = antrophic_llm._get_messages(prompt, system_prompt) + messages = anthropic_llm._get_messages(prompt, system_prompt) assert messages == [ SystemMessage(content="Test System Prompt", additional_kwargs={}), HumanMessage(content="Test Prompt", additional_kwargs={}, example=False), ] -def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog): +def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog): with patch("langchain.chat_models.ChatAnthropic") as mock_chat: mock_chat_instance = mock_chat.return_value mock_chat_instance.return_value = MagicMock(content="Test Response") prompt = "Test Prompt" - config = antrophic_llm.config + config = anthropic_llm.config config.max_tokens = 500 - response = antrophic_llm._get_answer(prompt, config) + response = anthropic_llm._get_answer(prompt, config) assert response == "Test Response" mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model) diff --git a/tests/llm/test_azure_openai.py b/tests/llm/test_azure_openai.py index 7f2c7614..9ab3479d 100644 --- a/tests/llm/test_azure_openai.py +++ b/tests/llm/test_azure_openai.py @@ -1,9 +1,11 @@ -import pytest from unittest.mock import MagicMock, patch -from embedchain.llm.azure_openai import AzureOpenAILlm -from embedchain.config import BaseLlmConfig + +import pytest from langchain.schema import HumanMessage, SystemMessage +from embedchain.config import BaseLlmConfig +from embedchain.llm.azure_openai import AzureOpenAILlm + @pytest.fixture def azure_openai_llm(): diff --git a/tests/llm/test_hugging_face_hub.py b/tests/llm/test_hugging_face_hub.py index 63b1bfed..d4d9d64d 100644 --- a/tests/llm/test_hugging_face_hub.py +++ b/tests/llm/test_hugging_face_hub.py @@ -1,7 +1,7 @@ import importlib import os import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from embedchain.config import BaseLlmConfig from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm diff --git a/tests/llm/test_vertex_ai.py b/tests/llm/test_vertex_ai.py index 952d5c39..9941523d 100644 --- a/tests/llm/test_vertex_ai.py +++ b/tests/llm/test_vertex_ai.py @@ -1,9 +1,11 @@ -import pytest from unittest.mock import MagicMock, patch -from embedchain.llm.vertex_ai import VertexAiLlm -from embedchain.config import BaseLlmConfig + +import pytest from langchain.schema import HumanMessage, SystemMessage +from embedchain.config import BaseLlmConfig +from embedchain.llm.vertex_ai import VertexAiLlm + @pytest.fixture def vertexai_llm(): diff --git a/tests/loaders/test_docs_site.py b/tests/loaders/test_docs_site.py index e27bd1bf..31d03f67 100644 --- a/tests/loaders/test_docs_site.py +++ b/tests/loaders/test_docs_site.py @@ -1,7 +1,9 @@ import hashlib -import pytest from unittest.mock import Mock, patch + +import pytest from requests import Response + from embedchain.loaders.docs_site_loader import DocsSiteLoader diff --git a/tests/loaders/test_docx_file.py b/tests/loaders/test_docx_file.py index 6b3bb193..b7deffcb 100644 --- a/tests/loaders/test_docx_file.py +++ b/tests/loaders/test_docx_file.py @@ -1,6 +1,8 @@ import hashlib -import pytest from unittest.mock import MagicMock, patch + +import pytest + from embedchain.loaders.docx_file import DocxFileLoader diff --git a/tests/loaders/test_local_qna_pair.py b/tests/loaders/test_local_qna_pair.py index 29447d19..5bdfd2ca 100644 --- a/tests/loaders/test_local_qna_pair.py +++ b/tests/loaders/test_local_qna_pair.py @@ -1,5 +1,7 @@ import hashlib + import pytest + from embedchain.loaders.local_qna_pair import LocalQnaPairLoader diff --git a/tests/loaders/test_local_text.py b/tests/loaders/test_local_text.py index 7d350ea5..58b6ec8f 100644 --- a/tests/loaders/test_local_text.py +++ b/tests/loaders/test_local_text.py @@ -1,5 +1,7 @@ import hashlib + import pytest + from embedchain.loaders.local_text import LocalTextLoader diff --git a/tests/loaders/test_mdx.py b/tests/loaders/test_mdx.py index 960eb6e5..d4826209 100644 --- a/tests/loaders/test_mdx.py +++ b/tests/loaders/test_mdx.py @@ -1,6 +1,8 @@ import hashlib +from unittest.mock import mock_open, patch + import pytest -from unittest.mock import patch, mock_open + from embedchain.loaders.mdx import MdxLoader diff --git a/tests/loaders/test_notion.py b/tests/loaders/test_notion.py index f4d30327..ca5f6d21 100644 --- a/tests/loaders/test_notion.py +++ b/tests/loaders/test_notion.py @@ -1,7 +1,9 @@ import hashlib import os -import pytest from unittest.mock import Mock, patch + +import pytest + from embedchain.loaders.notion import NotionLoader diff --git a/tests/loaders/test_web_page.py b/tests/loaders/test_web_page.py index 61b9031b..cdaf0944 100644 --- a/tests/loaders/test_web_page.py +++ b/tests/loaders/test_web_page.py @@ -1,6 +1,8 @@ import hashlib -import pytest from unittest.mock import Mock, patch + +import pytest + from embedchain.loaders.web_page import WebPageLoader diff --git a/tests/loaders/test_youtube_video.py b/tests/loaders/test_youtube_video.py index cc70d779..d4733161 100644 --- a/tests/loaders/test_youtube_video.py +++ b/tests/loaders/test_youtube_video.py @@ -1,6 +1,8 @@ import hashlib -import pytest from unittest.mock import MagicMock, Mock, patch + +import pytest + from embedchain.loaders.youtube_video import YoutubeVideoLoader diff --git a/tests/test_factory.py b/tests/test_factory.py new file mode 100644 index 00000000..5c6e81ed --- /dev/null +++ b/tests/test_factory.py @@ -0,0 +1,61 @@ +import pytest + +import embedchain +import embedchain.embedder.gpt4all +import embedchain.embedder.huggingface +import embedchain.embedder.openai +import embedchain.embedder.vertexai +import embedchain.llm.anthropic +import embedchain.llm.openai +import embedchain.vectordb.chroma +import embedchain.vectordb.elasticsearch +import embedchain.vectordb.opensearch +from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory + + +class TestFactories: + @pytest.mark.parametrize( + "provider_name, config_data, expected_class", + [ + ("openai", {}, embedchain.llm.openai.OpenAILlm), + ("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm), + ], + ) + def test_llm_factory_create(self, provider_name, config_data, expected_class): + llm_instance = LlmFactory.create(provider_name, config_data) + assert isinstance(llm_instance, expected_class) + + @pytest.mark.parametrize( + "provider_name, config_data, expected_class", + [ + ("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder), + ( + "huggingface", + {"model": "sentence-transformers/all-mpnet-base-v2"}, + embedchain.embedder.huggingface.HuggingFaceEmbedder, + ), + ("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAiEmbedder), + ("openai", {}, embedchain.embedder.openai.OpenAIEmbedder), + ], + ) + def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class): + mocker.patch("embedchain.embedder.vertexai.VertexAiEmbedder", autospec=True) + embedder_instance = EmbedderFactory.create(provider_name, config_data) + assert isinstance(embedder_instance, expected_class) + + @pytest.mark.parametrize( + "provider_name, config_data, expected_class", + [ + ("chroma", {}, embedchain.vectordb.chroma.ChromaDB), + ( + "opensearch", + {"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")}, + embedchain.vectordb.opensearch.OpenSearchDB, + ), + ("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB), + ], + ) + def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class): + mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True) + vectordb_instance = VectorDBFactory.create(provider_name, config_data) + assert isinstance(vectordb_instance, expected_class) diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 3be89d49..0472ec0f 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -28,19 +28,25 @@ class TestChromaDbHosts(unittest.TestCase): host = "test-host" port = "1234" - chroma_auth_settings = { - "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider", - "chroma_client_auth_credentials": "admin:admin", + chroma_config = { + "host": host, + "port": port, + "chroma_settings": { + "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider", + "chroma_client_auth_credentials": "admin:admin", + }, } - config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings) + config = ChromaDbConfig(**chroma_config) db = ChromaDB(config=config) settings = db.client.get_settings() self.assertEqual(settings.chroma_server_host, host) self.assertEqual(settings.chroma_server_http_port, port) - self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"]) self.assertEqual( - settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"] + settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"] + ) + self.assertEqual( + settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"] ) @@ -55,9 +61,9 @@ class TestChromaDbHostsInit(unittest.TestCase): port = "1234" config = AppConfig(collect_metrics=False) - chromadb_config = ChromaDbConfig(host=host, port=port) + db_config = ChromaDbConfig(host=host, port=port) - _app = App(config, chromadb_config=chromadb_config) + _app = App(config, db_config=db_config) called_settings: Settings = mock_client.call_args[0][0] @@ -95,7 +101,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): class TestChromaDbDuplicateHandling: chroma_config = ChromaDbConfig(allow_reset=True) app_config = AppConfig(collection_name=False, collect_metrics=False) - app_with_settings = App(config=app_config, chromadb_config=chroma_config) + app_with_settings = App(config=app_config, db_config=chroma_config) def test_duplicates_throw_warning(self, caplog): """ @@ -131,7 +137,7 @@ class TestChromaDbDuplicateHandling: class TestChromaDbCollection(unittest.TestCase): chroma_config = ChromaDbConfig(allow_reset=True) app_config = AppConfig(collection_name=False, collect_metrics=False) - app_with_settings = App(config=app_config, chromadb_config=chroma_config) + app_with_settings = App(config=app_config, db_config=chroma_config) def test_init_with_default_collection(self): """ @@ -296,7 +302,7 @@ class TestChromaDbCollection(unittest.TestCase): # Create four apps. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. - app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config) + app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config) app1.set_collection_name("one_collection") app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) app2.set_collection_name("one_collection") diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py index b4149d42..6dca78f7 100644 --- a/tests/vectordb/test_zilliz_db.py +++ b/tests/vectordb/test_zilliz_db.py @@ -1,14 +1,15 @@ # ruff: noqa: E501 import os +from unittest import mock +from unittest.mock import Mock, patch import pytest -from unittest import mock -from unittest.mock import patch, Mock from embedchain.config import ZillizDBConfig from embedchain.vectordb.zilliz import ZillizVectorDB + # to run tests, provide the URI and TOKEN in .env file class TestZillizVectorDBConfig: @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"}) @@ -51,6 +52,7 @@ class TestZillizVectorDBConfig: with pytest.raises(AttributeError): ZillizDBConfig() + class TestZillizVectorDB: @pytest.fixture @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"}) @@ -147,7 +149,7 @@ class TestZillizDBCollection: zilliz_db = ZillizVectorDB(config=mock_config) # Add a 'embedder' attribute to the ZillizVectorDB instance for testing - zilliz_db.embedder = mock_embedder # Mock the 'collection' object + zilliz_db.embedder = mock_embedder # Mock the 'collection' object # Add a 'collection' attribute to the ZillizVectorDB instance for testing zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object