diff --git a/embedchain/config/InitConfig.py b/embedchain/config/InitConfig.py index f0614956..a6a91f9b 100644 --- a/embedchain/config/InitConfig.py +++ b/embedchain/config/InitConfig.py @@ -1,6 +1,8 @@ import logging import os +from chromadb.utils import embedding_functions + from embedchain.config.BaseConfig import BaseConfig @@ -37,11 +39,40 @@ class InitConfig(BaseConfig): else: self.db = db + self.ef = ef + self.db = db return def _set_embedding_function(self, ef): self.ef = ef return + + def _set_embedding_function_to_default(self): + """ + Sets embedding function to default (`text-embedding-ada-002`). + + :raises ValueError: If the template is not valid as template should contain $context and $query + """ + if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_ORGANIZATION") is None: + raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") + self.ef = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), + organization_id=os.getenv("OPENAI_ORGANIZATION"), + model_name="text-embedding-ada-002" + ) + return + + def _set_db(self, db): + if db: + self.db = db + return + + def _set_db_to_default(self): + """ + Sets database to default (`ChromaDb`). + """ + from embedchain.vectordb.chroma_db import ChromaDB + self.db = ChromaDB(ef=self.ef) def _setup_logging(self, debug_level): level = logging.WARNING # Default level diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 0e7f60ee..bf548618 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -297,6 +297,13 @@ class App(EmbedChain): """ if config is None: config = InitConfig() + + if not config.ef: + config._set_embedding_function_to_default() + + if not config.db: + config._set_db_to_default() + super().__init__(config) def get_llm_model_answer(self, prompt, config: ChatConfig): @@ -345,17 +352,17 @@ class OpenSourceApp(EmbedChain): "Loading open source embedding model. This may take some time..." ) # noqa:E501 if not config: - config = InitConfig( - ef=embedding_functions.SentenceTransformerEmbeddingFunction( - model_name="all-MiniLM-L6-v2" - ) - ) - elif not config.ef: + config = InitConfig() + + if not config.ef: config._set_embedding_function( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name="all-MiniLM-L6-v2" - ) - ) + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name="all-MiniLM-L6-v2" + )) + + if not config.db: + config._set_db_to_default() + print("Successfully loaded open source embedding model.") super().__init__(config)