Build default embedding functions on demand instead of at import time.

The module-level `openai_ef` and `sentence_transformer_ef` objects are removed
from embedchain.py and chroma_db.py; each constructor now creates its default
embedding function only when the caller passes ef=None.

Review notes:
- App defined `__int__`, which Python never calls as a constructor, so the
  default added there was unreachable; it is renamed to `__init__`.
- ChromaDB keeps the explicit `ef is not None` test used by the code it
  replaces, so a caller-supplied embedding function is kept even if it
  evaluates falsy.
- NOTE(review): line breaks, indentation and blank context lines were
  reconstructed from the hunk line counts; verify against the target tree
  before applying. The `index` blob ids are carried over unchanged and will
  not match the edited post-image.

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 4606cb08..65ec7361 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -19,12 +19,6 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.vectordb.chroma_db import ChromaDB
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    organization_id=os.getenv("OPENAI_ORGANIZATION"),
-    model_name="text-embedding-ada-002"
-)
-sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
 
 gpt4all_model = None
 
@@ -238,7 +232,11 @@ class App(EmbedChain):
 
-    def __int__(self, db=None, ef=None):
+    def __init__(self, db=None, ef=None):
         if ef is None:
-            ef = openai_ef
+            ef = embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            )
         super().__init__(db, ef)
 
     def get_llm_model_answer(self, prompt):
@@ -270,7 +268,9 @@ class OpenSourceApp(EmbedChain):
     def __init__(self, db=None, ef=None):
         print("Loading open source embedding model. This may take some time...")
         if ef is None:
-            ef = sentence_transformer_ef
+            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name="all-MiniLM-L6-v2"
+            )
         print("Successfully loaded open source embedding model.")
         super().__init__(db, ef)
 
diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py
index 96166f83..6cfa7db1 100644
--- a/embedchain/vectordb/chroma_db.py
+++ b/embedchain/vectordb/chroma_db.py
@@ -5,15 +5,17 @@ from chromadb.utils import embedding_functions
 
 from embedchain.vectordb.base_vector_db import BaseVectorDB
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    organization_id=os.getenv("OPENAI_ORGANIZATION"),
-    model_name="text-embedding-ada-002"
-)
 
 class ChromaDB(BaseVectorDB):
     def __init__(self, db_dir=None, ef=None):
-        self.ef = ef if ef is not None else openai_ef
+        if ef is not None:
+            self.ef = ef
+        else:
+            self.ef = embedding_functions.OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                organization_id=os.getenv("OPENAI_ORGANIZATION"),
+                model_name="text-embedding-ada-002"
+            )
         if db_dir is None:
             db_dir = "db"
         self.client_settings = chromadb.config.Settings(