From cff244b894f2c5df544b1fd6baeb741a3b64c276 Mon Sep 17 00:00:00 2001 From: Taranjeet Singh Date: Fri, 23 Jun 2023 11:51:31 +0530 Subject: [PATCH] Add default db loader, fix import This commit builds on DumoeDss's PR. It - adds a default db directory name. - adds a default db instance (Chroma). Both points offer flexibility for users who want to use the defaults and users who want to customize. Lastly, it fixes an import. --- embedchain/embedchain.py | 5 ++++- embedchain/vectordb/chroma_db.py | 10 +++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index cd3f5c33..d6cd2f5f 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -11,6 +11,7 @@ from embedchain.loaders.web_page import WebPageLoader from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.web_page import WebPageChunker +from embedchain.vectordb.chroma_db import ChromaDB load_dotenv() @@ -21,13 +22,15 @@ DB_DIR = os.path.join(ABS_PATH, "db") class EmbedChain: - def __init__(self, db): + def __init__(self, db=None): """ Initializes the EmbedChain instance, sets up a vector DB client and creates a collection. :param db: The instance of the VectorDB subclass. 
""" + if db is None: + db = ChromaDB() self.db_client = db.client self.collection = db.collection self.user_asks = [] diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 34ce08cc..e4981394 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,15 +1,19 @@ -import os import chromadb -from base_vector_db import BaseVectorDB +import os + from chromadb.utils import embedding_functions +from embedchain.vectordb.base_vector_db import BaseVectorDB + openai_ef = embedding_functions.OpenAIEmbeddingFunction( api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-ada-002" ) class ChromaDB(BaseVectorDB): - def __init__(self, db_dir): + def __init__(self, db_dir=None): + if db_dir is None: + db_dir = "db" self.client_settings = chromadb.config.Settings( chroma_db_impl="duckdb+parquet", persist_directory=db_dir,