Merge pull request #22 from DumoeDss/feature_add_other_vectordb

[feat] Refactor VectorDB class hierarchy for flexibility
This commit is contained in:
Taranjeet Singh
2023-06-23 12:08:39 +05:30
committed by GitHub
3 changed files with 50 additions and 35 deletions

View File

@@ -1,8 +1,6 @@
import chromadb
import openai import openai
import os import os
from chromadb.utils import embedding_functions
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
@@ -13,6 +11,7 @@ from embedchain.loaders.web_page import WebPageLoader
from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.web_page import WebPageChunker from embedchain.chunkers.web_page import WebPageChunker
from embedchain.vectordb.chroma_db import ChromaDB
load_dotenv() load_dotenv()
@@ -21,20 +20,19 @@ embeddings = OpenAIEmbeddings()
ABS_PATH = os.getcwd() ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db") DB_DIR = os.path.join(ABS_PATH, "db")
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-ada-002"
)
class EmbedChain: class EmbedChain:
def __init__(self): def __init__(self, db=None):
""" """
Initializes the EmbedChain instance, sets up a ChromaDB client and Initializes the EmbedChain instance, sets up a vector DB client and
creates a ChromaDB collection. creates a collection.
:param db: The instance of the VectorDB subclass.
""" """
self.chromadb_client = self._get_or_create_db() if db is None:
self.collection = self._get_or_create_collection() db = ChromaDB()
self.db_client = db.client
self.collection = db.collection
self.user_asks = [] self.user_asks = []
def _get_loader(self, data_type): def _get_loader(self, data_type):
@@ -87,29 +85,6 @@ class EmbedChain:
self.user_asks.append([data_type, url]) self.user_asks.append([data_type, url])
self.load_and_embed(loader, chunker, url) self.load_and_embed(loader, chunker, url)
def _get_or_create_db(self):
"""
Returns a ChromaDB client, creates a new one if needed.
:return: The ChromaDB client.
"""
client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=DB_DIR,
anonymized_telemetry=False
)
return chromadb.Client(client_settings)
def _get_or_create_collection(self):
"""
Returns a ChromaDB collection, creates a new one if needed.
:return: The ChromaDB collection.
"""
return self.chromadb_client.get_or_create_collection(
'embedchain_store', embedding_function=openai_ef,
)
def load_and_embed(self, loader, chunker, url): def load_and_embed(self, loader, chunker, url):
""" """
Loads the data from the given URL, chunks it, and adds it to the database. Loads the data from the given URL, chunks it, and adds it to the database.

View File

@@ -0,0 +1,10 @@
class BaseVectorDB:
    """
    Base class for vector database backends.

    Constructing an instance calls ``_get_or_create_db`` and then
    ``_get_or_create_collection`` and stores their results on ``client``
    and ``collection``. Subclasses must override both hooks.
    """

    def __init__(self):
        """
        Creates the client and then the collection via the subclass hooks.
        """
        # The client is assigned first so that a subclass's collection hook
        # can read ``self.client``.
        self.client = self._get_or_create_db()
        self.collection = self._get_or_create_collection()

    def _get_or_create_db(self):
        """
        Returns the database client. Must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_or_create_db()"
        )

    def _get_or_create_collection(self):
        """
        Returns the collection. Must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_or_create_collection()"
        )

View File

@@ -0,0 +1,30 @@
import chromadb
import os
from chromadb.utils import embedding_functions
from embedchain.vectordb.base_vector_db import BaseVectorDB
# Embedding function passed to the Chroma collection created by ChromaDB below.
# NOTE: os.getenv is evaluated when this module is imported, so the value of
# OPENAI_API_KEY is whatever the environment holds at import time.
# NOTE(review): presumably a caller that loads a .env file only after importing
# this module would end up passing api_key=None here — confirm import order.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-ada-002"
)
class ChromaDB(BaseVectorDB):
    """
    Vector database backend that uses a Chroma client and the
    ``embedchain_store`` collection.
    """

    def __init__(self, db_dir=None, embedding_function=None):
        """
        Builds the Chroma settings, then lets the base class create the
        client and the collection.

        :param db_dir: Directory passed to Chroma as ``persist_directory``.
        Defaults to ``"db"``.
        :param embedding_function: Embedding function passed to the
        collection. Defaults to the module-level ``openai_ef``.
        """
        if db_dir is None:
            db_dir = "db"
        if embedding_function is None:
            embedding_function = openai_ef
        self.embedding_function = embedding_function
        self.client_settings = chromadb.config.Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=db_dir,
            anonymized_telemetry=False,
        )
        # The base initializer invokes the two hooks below, which read the
        # attributes assigned above, so it has to run last.
        super().__init__()

    def _get_or_create_db(self):
        """
        Returns a Chroma client built from ``self.client_settings``.

        :return: The Chroma client.
        """
        return chromadb.Client(self.client_settings)

    def _get_or_create_collection(self):
        """
        Returns the ``embedchain_store`` collection, creating it if needed.

        :return: The Chroma collection.
        """
        return self.client.get_or_create_collection(
            'embedchain_store', embedding_function=self.embedding_function,
        )