From 85a6a0c1612e783f5cc3e25e5fbc13fab3e57549 Mon Sep 17 00:00:00 2001 From: Sayo Date: Thu, 22 Jun 2023 12:15:26 +0800 Subject: [PATCH 1/2] [feat] Refactor VectorDB class hierarchy for flexibility --- embedchain/embedchain.py | 42 +++++---------------------- embedchain/vectordb/base_vector_db.py | 10 +++++++ embedchain/vectordb/chroma_db.py | 26 +++++++++++++++++ 3 files changed, 43 insertions(+), 35 deletions(-) create mode 100644 embedchain/vectordb/base_vector_db.py create mode 100644 embedchain/vectordb/chroma_db.py diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 62fec8ea..cd3f5c33 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -1,8 +1,6 @@ -import chromadb import openai import os -from chromadb.utils import embedding_functions from dotenv import load_dotenv from langchain.docstore.document import Document from langchain.embeddings.openai import OpenAIEmbeddings @@ -21,20 +19,17 @@ embeddings = OpenAIEmbeddings() ABS_PATH = os.getcwd() DB_DIR = os.path.join(ABS_PATH, "db") -openai_ef = embedding_functions.OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), - model_name="text-embedding-ada-002" -) - class EmbedChain: - def __init__(self): + def __init__(self, db): """ - Initializes the EmbedChain instance, sets up a ChromaDB client and - creates a ChromaDB collection. + Initializes the EmbedChain instance, sets up a vector DB client and + creates a collection. + + :param db: The instance of the VectorDB subclass. """ - self.chromadb_client = self._get_or_create_db() - self.collection = self._get_or_create_collection() + self.db_client = db.client + self.collection = db.collection self.user_asks = [] def _get_loader(self, data_type): @@ -87,29 +82,6 @@ class EmbedChain: self.user_asks.append([data_type, url]) self.load_and_embed(loader, chunker, url) - def _get_or_create_db(self): - """ - Returns a ChromaDB client, creates a new one if needed. - - :return: The ChromaDB client. - """ - client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", - persist_directory=DB_DIR, - anonymized_telemetry=False - ) - return chromadb.Client(client_settings) - - def _get_or_create_collection(self): - """ - Returns a ChromaDB collection, creates a new one if needed. - - :return: The ChromaDB collection. - """ - return self.chromadb_client.get_or_create_collection( - 'embedchain_store', embedding_function=openai_ef, - ) - def load_and_embed(self, loader, chunker, url): """ Loads the data from the given URL, chunks it, and adds it to the database. diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py new file mode 100644 index 00000000..190646f3 --- /dev/null +++ b/embedchain/vectordb/base_vector_db.py @@ -0,0 +1,10 @@ +class BaseVectorDB: + def __init__(self): + self.client = self._get_or_create_db() + self.collection = self._get_or_create_collection() + + def _get_or_create_db(self): + raise NotImplementedError + + def _get_or_create_collection(self): + raise NotImplementedError \ No newline at end of file diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py new file mode 100644 index 00000000..34ce08cc --- /dev/null +++ b/embedchain/vectordb/chroma_db.py @@ -0,0 +1,26 @@ +import os +import chromadb +from base_vector_db import BaseVectorDB +from chromadb.utils import embedding_functions + +openai_ef = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), + model_name="text-embedding-ada-002" +) + +class ChromaDB(BaseVectorDB): + def __init__(self, db_dir): + self.client_settings = chromadb.config.Settings( + chroma_db_impl="duckdb+parquet", + persist_directory=db_dir, + anonymized_telemetry=False + ) + super().__init__() + + def _get_or_create_db(self): + return chromadb.Client(self.client_settings) + + def _get_or_create_collection(self): + return self.client.get_or_create_collection( + 'embedchain_store', embedding_function=openai_ef, + ) \ No newline at end of file From cff244b894f2c5df544b1fd6baeb741a3b64c276 Mon Sep 17 00:00:00 2001 From: Taranjeet Singh Date: Fri, 23 Jun 2023 11:51:31 +0530 Subject: [PATCH 2/2] Add default db loader, fix import This commits builds on DumoeDss's PR. It - adds a default db directory name. - adds a default db instance (Chroma). Both points offers flexibility for users who want to use default and users who want to customize Lastly, it fixes an import --- embedchain/embedchain.py | 5 ++++- embedchain/vectordb/chroma_db.py | 10 +++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index cd3f5c33..d6cd2f5f 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -11,6 +11,7 @@ from embedchain.loaders.web_page import WebPageLoader from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.web_page import WebPageChunker +from embedchain.vectordb.chroma_db import ChromaDB load_dotenv() @@ -21,13 +22,15 @@ DB_DIR = os.path.join(ABS_PATH, "db") class EmbedChain: - def __init__(self, db): + def __init__(self, db=None): """ Initializes the EmbedChain instance, sets up a vector DB client and creates a collection. :param db: The instance of the VectorDB subclass. """ + if db is None: + db = ChromaDB() self.db_client = db.client self.collection = db.collection self.user_asks = [] diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 34ce08cc..e4981394 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,15 +1,19 @@ -import os import chromadb -from base_vector_db import BaseVectorDB +import os + from chromadb.utils import embedding_functions +from embedchain.vectordb.base_vector_db import BaseVectorDB + openai_ef = embedding_functions.OpenAIEmbeddingFunction( api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-ada-002" ) class ChromaDB(BaseVectorDB): - def __init__(self, db_dir): + def __init__(self, db_dir=None): + if db_dir is None: + db_dir = "db" self.client_settings = chromadb.config.Settings( chroma_db_impl="duckdb+parquet", persist_directory=db_dir,