[feat] Refactor VectorDB class hierarchy for flexibility

This commit is contained in:
Sayo
2023-06-22 12:15:26 +08:00
parent 973dc5434f
commit 85a6a0c161
3 changed files with 43 additions and 35 deletions

View File

@@ -1,8 +1,6 @@
import chromadb
import openai import openai
import os import os
from chromadb.utils import embedding_functions
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
@@ -21,20 +19,17 @@ embeddings = OpenAIEmbeddings()
ABS_PATH = os.getcwd() ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db") DB_DIR = os.path.join(ABS_PATH, "db")
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-ada-002"
)
class EmbedChain: class EmbedChain:
def __init__(self): def __init__(self, db):
""" """
Initializes the EmbedChain instance, sets up a ChromaDB client and Initializes the EmbedChain instance, sets up a vector DB client and
creates a ChromaDB collection. creates a collection.
:param db: The instance of the VectorDB subclass.
""" """
self.chromadb_client = self._get_or_create_db() self.db_client = db.client
self.collection = self._get_or_create_collection() self.collection = db.collection
self.user_asks = [] self.user_asks = []
def _get_loader(self, data_type): def _get_loader(self, data_type):
@@ -87,29 +82,6 @@ class EmbedChain:
self.user_asks.append([data_type, url]) self.user_asks.append([data_type, url])
self.load_and_embed(loader, chunker, url) self.load_and_embed(loader, chunker, url)
def _get_or_create_db(self):
"""
Returns a ChromaDB client, creates a new one if needed.
:return: The ChromaDB client.
"""
client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=DB_DIR,
anonymized_telemetry=False
)
return chromadb.Client(client_settings)
def _get_or_create_collection(self):
"""
Returns a ChromaDB collection, creates a new one if needed.
:return: The ChromaDB collection.
"""
return self.chromadb_client.get_or_create_collection(
'embedchain_store', embedding_function=openai_ef,
)
def load_and_embed(self, loader, chunker, url): def load_and_embed(self, loader, chunker, url):
""" """
Loads the data from the given URL, chunks it, and adds it to the database. Loads the data from the given URL, chunks it, and adds it to the database.

View File

@@ -0,0 +1,10 @@
class BaseVectorDB:
    """Common skeleton for vector database backends.

    Construction runs two hooks in a fixed order: the client is created
    first and stored on ``self.client``, then the collection is created and
    stored on ``self.collection``. Subclasses supply both hooks.
    """

    def __init__(self):
        """Build the client, then the collection, via the subclass hooks."""
        # Order matters: the collection hook may rely on ``self.client``
        # already being set.
        client = self._get_or_create_db()
        self.client = client
        collection = self._get_or_create_collection()
        self.collection = collection

    def _get_or_create_db(self):
        """Return the database client; must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def _get_or_create_collection(self):
        """Return the collection; must be overridden by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

View File

@@ -0,0 +1,26 @@
import os
import chromadb
from base_vector_db import BaseVectorDB
from chromadb.utils import embedding_functions
# Embedding function shared by this module's collections. It is built once at
# import time, so OPENAI_API_KEY is read from the environment at that moment.
# NOTE(review): if the key is loaded from a .env file, that presumably has to
# happen before this module is imported — confirm against the caller.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    model_name="text-embedding-ada-002",
    api_key=os.getenv("OPENAI_API_KEY"),
)
class ChromaDB(BaseVectorDB):
    """Vector database backend that uses a Chroma client and collection."""

    def __init__(self, db_dir):
        """Prepare the Chroma settings, then let the base class build the
        client and the collection.

        :param db_dir: Passed to Chroma as ``persist_directory``.
        """
        settings = chromadb.config.Settings(
            anonymized_telemetry=False,
            persist_directory=db_dir,
            chroma_db_impl="duckdb+parquet",
        )
        # Must be assigned before super().__init__(): the base constructor
        # calls _get_or_create_db(), which reads self.client_settings.
        self.client_settings = settings
        super().__init__()

    def _get_or_create_db(self):
        """Create a Chroma client from the stored settings.

        :return: The Chroma client.
        """
        client = chromadb.Client(self.client_settings)
        return client

    def _get_or_create_collection(self):
        """Fetch the 'embedchain_store' collection, creating it if needed.

        :return: The Chroma collection, using the module-level OpenAI
            embedding function.
        """
        collection_name = 'embedchain_store'
        return self.client.get_or_create_collection(
            collection_name,
            embedding_function=openai_ef,
        )