From 9409e5605a3b51abb0fc127a3bd629859cd1ed03 Mon Sep 17 00:00:00 2001 From: cachho Date: Tue, 11 Jul 2023 14:04:57 +0200 Subject: [PATCH] chore: linting (#230) --- embedchain/config/InitConfig.py | 27 +++++++++++++++++---------- embedchain/embedchain.py | 13 +++++++------ setup.py | 8 +------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/embedchain/config/InitConfig.py b/embedchain/config/InitConfig.py index a6a91f9b..5189e890 100644 --- a/embedchain/config/InitConfig.py +++ b/embedchain/config/InitConfig.py @@ -46,25 +46,31 @@ class InitConfig(BaseConfig): def _set_embedding_function(self, ef): self.ef = ef return - + def _set_embedding_function_to_default(self): """ Sets embedding function to default (`text-embedding-ada-002`). - :raises ValueError: If the template is not valid as template should contain $context and $query + :raises ValueError: If the template is not valid as template should contain + $context and $query """ - if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_ORGANIZATION") is None: - raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") - self.ef = embedding_functions.OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), - organization_id=os.getenv("OPENAI_ORGANIZATION"), - model_name="text-embedding-ada-002" + if ( + os.getenv("OPENAI_API_KEY") is None + or os.getenv("OPENAI_ORGANIZATION") is None + ): + raise ValueError( + "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided" # noqa:E501 ) + self.ef = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), + organization_id=os.getenv("OPENAI_ORGANIZATION"), + model_name="text-embedding-ada-002", + ) return - + def _set_db(self, db): if db: - self.db = db + self.db = db return def _set_db_to_default(self): @@ -72,6 +78,7 @@ class InitConfig(BaseConfig): Sets database to default (`ChromaDb`). """ from embedchain.vectordb.chroma_db import ChromaDB + self.db = ChromaDB(ef=self.ef) def _setup_logging(self, debug_level): diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 9efff18d..79b7239d 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -301,13 +301,13 @@ class App(EmbedChain): """ if config is None: config = InitConfig() - + if not config.ef: config._set_embedding_function_to_default() if not config.db: config._set_db_to_default() - + super().__init__(config) def get_llm_model_answer(self, prompt, config: ChatConfig): @@ -357,12 +357,13 @@ class OpenSourceApp(EmbedChain): ) # noqa:E501 if not config: config = InitConfig() - + if not config.ef: config._set_embedding_function( - embedding_functions.SentenceTransformerEmbeddingFunction( - model_name="all-MiniLM-L6-v2" - )) + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name="all-MiniLM-L6-v2" + ) + ) if not config.db: config._set_db_to_default() diff --git a/setup.py b/setup.py index 60c5524e..d4ce22e5 100644 --- a/setup.py +++ b/setup.py @@ -34,11 +34,5 @@ setuptools.setup( "docx2txt", "pydantic==1.10.8", ], - extras_require={ - "dev": [ - "black", - "ruff", - "isort" - ] - } + extras_require={"dev": ["black", "ruff", "isort"]}, )