diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 62fec8ea..d6cd2f5f 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -1,8 +1,6 @@
-import chromadb
 import openai
 import os
 
-from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -13,6 +11,7 @@ from embedchain.loaders.web_page import WebPageLoader
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.web_page import WebPageChunker
+from embedchain.vectordb.chroma_db import ChromaDB
 
 load_dotenv()
 
@@ -21,20 +20,19 @@ embeddings = OpenAIEmbeddings()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    model_name="text-embedding-ada-002"
-)
-
 
 class EmbedChain:
-    def __init__(self):
+    def __init__(self, db=None):
         """
-        Initializes the EmbedChain instance, sets up a ChromaDB client and
-        creates a ChromaDB collection.
+        Initializes the EmbedChain instance, sets up a vector DB client and
+        creates a collection.
+
+        :param db: The instance of the VectorDB subclass.
         """
-        self.chromadb_client = self._get_or_create_db()
-        self.collection = self._get_or_create_collection()
+        if db is None:
+            db = ChromaDB()
+        self.db_client = db.client
+        self.collection = db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
@@ -87,29 +85,6 @@ class EmbedChain:
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def _get_or_create_db(self):
-        """
-        Returns a ChromaDB client, creates a new one if needed.
-
-        :return: The ChromaDB client.
-        """
-        client_settings = chromadb.config.Settings(
-            chroma_db_impl="duckdb+parquet",
-            persist_directory=DB_DIR,
-            anonymized_telemetry=False
-        )
-        return chromadb.Client(client_settings)
-
-    def _get_or_create_collection(self):
-        """
-        Returns a ChromaDB collection, creates a new one if needed.
-
-        :return: The ChromaDB collection.
-        """
-        return self.chromadb_client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
-        )
-
     def load_and_embed(self, loader, chunker, url):
         """
         Loads the data from the given URL, chunks it, and adds it to the database.
diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py
new file mode 100644
index 00000000..190646f3
--- /dev/null
+++ b/embedchain/vectordb/base_vector_db.py
@@ -0,0 +1,26 @@
+class BaseVectorDB:
+    """Base class for the vector DB backends that EmbedChain can use."""
+
+    def __init__(self):
+        """
+        Creates the client first and the collection second, so a subclass
+        may read self.client inside _get_or_create_collection.
+        """
+        self.client = self._get_or_create_db()
+        self.collection = self._get_or_create_collection()
+
+    def _get_or_create_db(self):
+        """
+        Returns the vector DB client. Subclasses must override this.
+
+        :return: The vector DB client.
+        """
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        """
+        Returns the collection. Subclasses must override this.
+
+        :return: The collection.
+        """
+        raise NotImplementedError
diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py
new file mode 100644
index 00000000..e4981394
--- /dev/null
+++ b/embedchain/vectordb/chroma_db.py
@@ -0,0 +1,53 @@
+import chromadb
+import os
+
+from chromadb.utils import embedding_functions
+
+from embedchain.vectordb.base_vector_db import BaseVectorDB
+
+
+class ChromaDB(BaseVectorDB):
+    """Chroma-backed vector DB persisted to a local directory."""
+
+    def __init__(self, db_dir=None):
+        """
+        Builds the embedding function and Chroma settings, then lets the
+        base class create the client and the collection.
+
+        :param db_dir: The persistence directory, defaults to "db".
+        """
+        if db_dir is None:
+            db_dir = "db"
+        # Read OPENAI_API_KEY here, not at import time: embedchain.py imports
+        # this module before it calls load_dotenv(), so a key that lives only
+        # in .env is not in the environment yet when the module is loaded.
+        self.embedding_function = embedding_functions.OpenAIEmbeddingFunction(
+            api_key=os.getenv("OPENAI_API_KEY"),
+            model_name="text-embedding-ada-002"
+        )
+        self.client_settings = chromadb.config.Settings(
+            chroma_db_impl="duckdb+parquet",
+            persist_directory=db_dir,
+            anonymized_telemetry=False
+        )
+        # BaseVectorDB.__init__ calls the two methods below, so everything
+        # they read must be assigned before this call.
+        super().__init__()
+
+    def _get_or_create_db(self):
+        """
+        Returns a ChromaDB client built from the stored settings.
+
+        :return: The ChromaDB client.
+        """
+        return chromadb.Client(self.client_settings)
+
+    def _get_or_create_collection(self):
+        """
+        Returns a ChromaDB collection, creates a new one if needed.
+
+        :return: The ChromaDB collection.
+        """
+        return self.client.get_or_create_collection(
+            'embedchain_store', embedding_function=self.embedding_function,
+        )