From 200f11a0e0590c554a0379e40c94e82a0da7ce7c Mon Sep 17 00:00:00 2001 From: Taranjeet Singh Date: Wed, 5 Jul 2023 23:03:15 +0530 Subject: [PATCH] fix: Fix dependency of openai env variables for OpenSourceApp (#144) This commit removes the need to initialize OpenAI environment variables when using OpenSourceApp: the OpenAI embedding function is now constructed only when no embedding function is supplied, rather than at module import time. --- embedchain/embedchain.py | 16 ++++++++-------- embedchain/vectordb/chroma_db.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 4606cb08..65ec7361 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -19,12 +19,6 @@ from embedchain.chunkers.qna_pair import QnaPairChunker from embedchain.chunkers.text import TextChunker from embedchain.vectordb.chroma_db import ChromaDB -openai_ef = embedding_functions.OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), - organization_id=os.getenv("OPENAI_ORGANIZATION"), - model_name="text-embedding-ada-002" -) -sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2") gpt4all_model = None @@ -238,7 +232,11 @@ class App(EmbedChain): def __int__(self, db=None, ef=None): if ef is None: - ef = openai_ef + ef = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), + organization_id=os.getenv("OPENAI_ORGANIZATION"), + model_name="text-embedding-ada-002" + ) super().__init__(db, ef) def get_llm_model_answer(self, prompt): @@ -270,7 +268,9 @@ class OpenSourceApp(EmbedChain): def __init__(self, db=None, ef=None): print("Loading open source embedding model. 
This may take some time...") if ef is None: - ef = sentence_transformer_ef + ef = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name="all-MiniLM-L6-v2" + ) print("Successfully loaded open source embedding model.") super().__init__(db, ef) diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 96166f83..6cfa7db1 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -5,15 +5,17 @@ from chromadb.utils import embedding_functions from embedchain.vectordb.base_vector_db import BaseVectorDB -openai_ef = embedding_functions.OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), - organization_id=os.getenv("OPENAI_ORGANIZATION"), - model_name="text-embedding-ada-002" -) class ChromaDB(BaseVectorDB): def __init__(self, db_dir=None, ef=None): - self.ef = ef if ef is not None else openai_ef + if ef: + self.ef = ef + else: + self.ef = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), + organization_id=os.getenv("OPENAI_ORGANIZATION"), + model_name="text-embedding-ada-002" + ) if db_dir is None: db_dir = "db" self.client_settings = chromadb.config.Settings(