diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index cd3f5c33..d6cd2f5f 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -11,6 +11,7 @@ from embedchain.loaders.web_page import WebPageLoader from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.web_page import WebPageChunker +from embedchain.vectordb.chroma_db import ChromaDB load_dotenv() @@ -21,13 +22,15 @@ DB_DIR = os.path.join(ABS_PATH, "db") class EmbedChain: - def __init__(self, db): + def __init__(self, db=None): """ Initializes the EmbedChain instance, sets up a vector DB client and creates a collection. :param db: The instance of the VectorDB subclass. """ + if db is None: + db = ChromaDB() self.db_client = db.client self.collection = db.collection self.user_asks = [] diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 34ce08cc..e4981394 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,15 +1,19 @@ -import os import chromadb -from base_vector_db import BaseVectorDB +import os + from chromadb.utils import embedding_functions +from embedchain.vectordb.base_vector_db import BaseVectorDB + openai_ef = embedding_functions.OpenAIEmbeddingFunction( api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-ada-002" ) class ChromaDB(BaseVectorDB): - def __init__(self, db_dir): + def __init__(self, db_dir=None): + if db_dir is None: + db_dir = "db" self.client_settings = chromadb.config.Settings( chroma_db_impl="duckdb+parquet", persist_directory=db_dir,