From 85a6a0c1612e783f5cc3e25e5fbc13fab3e57549 Mon Sep 17 00:00:00 2001
From: Sayo
Date: Thu, 22 Jun 2023 12:15:26 +0800
Subject: [PATCH] [feat] Refactor VectorDB class hierarchy for flexibility

---
Review notes (this section is ignored by `git am`):

* chroma_db.py imports BaseVectorDB with an explicit relative import
  (`from .base_vector_db import BaseVectorDB`). The bare
  `from base_vector_db import BaseVectorDB` only resolves when
  embedchain/vectordb itself is on sys.path and raises ModuleNotFoundError
  when the module is imported as embedchain.vectordb.chroma_db.
* The `index` blob id of chroma_db.py is kept from the original patch and
  therefore no longer matches the patched content.
* Line structure and indentation were restored from the hunk headers;
  whitespace-only details are best-effort and should be confirmed against
  the target tree.

 embedchain/embedchain.py              | 42 +++++---------------------
 embedchain/vectordb/base_vector_db.py | 10 +++++++
 embedchain/vectordb/chroma_db.py      | 26 +++++++++++++++++
 3 files changed, 43 insertions(+), 35 deletions(-)
 create mode 100644 embedchain/vectordb/base_vector_db.py
 create mode 100644 embedchain/vectordb/chroma_db.py

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 62fec8ea..cd3f5c33 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -1,8 +1,6 @@
-import chromadb
 import openai
 import os
 
-from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -21,20 +19,17 @@ embeddings = OpenAIEmbeddings()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    model_name="text-embedding-ada-002"
-)
-
 
 class EmbedChain:
-    def __init__(self):
+    def __init__(self, db):
         """
-        Initializes the EmbedChain instance, sets up a ChromaDB client and
-        creates a ChromaDB collection.
+        Initializes the EmbedChain instance, sets up a vector DB client and
+        creates a collection.
+
+        :param db: The instance of the VectorDB subclass.
         """
-        self.chromadb_client = self._get_or_create_db()
-        self.collection = self._get_or_create_collection()
+        self.db_client = db.client
+        self.collection = db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
@@ -87,29 +82,6 @@ class EmbedChain:
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def _get_or_create_db(self):
-        """
-        Returns a ChromaDB client, creates a new one if needed.
-
-        :return: The ChromaDB client.
-        """
-        client_settings = chromadb.config.Settings(
-            chroma_db_impl="duckdb+parquet",
-            persist_directory=DB_DIR,
-            anonymized_telemetry=False
-        )
-        return chromadb.Client(client_settings)
-
-    def _get_or_create_collection(self):
-        """
-        Returns a ChromaDB collection, creates a new one if needed.
-
-        :return: The ChromaDB collection.
-        """
-        return self.chromadb_client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
-        )
-
     def load_and_embed(self, loader, chunker, url):
         """
         Loads the data from the given URL, chunks it, and adds it to the database.
diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py
new file mode 100644
index 00000000..190646f3
--- /dev/null
+++ b/embedchain/vectordb/base_vector_db.py
@@ -0,0 +1,10 @@
+class BaseVectorDB:
+    def __init__(self):
+        self.client = self._get_or_create_db()
+        self.collection = self._get_or_create_collection()
+
+    def _get_or_create_db(self):
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError
\ No newline at end of file
diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py
new file mode 100644
index 00000000..34ce08cc
--- /dev/null
+++ b/embedchain/vectordb/chroma_db.py
@@ -0,0 +1,26 @@
+import os
+import chromadb
+from .base_vector_db import BaseVectorDB
+from chromadb.utils import embedding_functions
+
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    model_name="text-embedding-ada-002"
+)
+
+class ChromaDB(BaseVectorDB):
+    def __init__(self, db_dir):
+        self.client_settings = chromadb.config.Settings(
+            chroma_db_impl="duckdb+parquet",
+            persist_directory=db_dir,
+            anonymized_telemetry=False
+        )
+        super().__init__()
+
+    def _get_or_create_db(self):
+        return chromadb.Client(self.client_settings)
+
+    def _get_or_create_collection(self):
+        return self.client.get_or_create_collection(
+            'embedchain_store', embedding_function=openai_ef,
+        )
\ No newline at end of file