[feat] Refactor VectorDB class hierarchy for flexibility
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
import chromadb
|
|
||||||
import openai
|
import openai
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from chromadb.utils import embedding_functions
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
@@ -21,20 +19,17 @@ embeddings = OpenAIEmbeddings()
|
|||||||
ABS_PATH = os.getcwd()
|
ABS_PATH = os.getcwd()
|
||||||
DB_DIR = os.path.join(ABS_PATH, "db")
|
DB_DIR = os.path.join(ABS_PATH, "db")
|
||||||
|
|
||||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
|
||||||
model_name="text-embedding-ada-002"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbedChain:
|
class EmbedChain:
|
||||||
def __init__(self):
|
def __init__(self, db):
|
||||||
"""
|
"""
|
||||||
Initializes the EmbedChain instance, sets up a ChromaDB client and
|
Initializes the EmbedChain instance, sets up a vector DB client and
|
||||||
creates a ChromaDB collection.
|
creates a collection.
|
||||||
|
|
||||||
|
:param db: The instance of the VectorDB subclass.
|
||||||
"""
|
"""
|
||||||
self.chromadb_client = self._get_or_create_db()
|
self.db_client = db.client
|
||||||
self.collection = self._get_or_create_collection()
|
self.collection = db.collection
|
||||||
self.user_asks = []
|
self.user_asks = []
|
||||||
|
|
||||||
def _get_loader(self, data_type):
|
def _get_loader(self, data_type):
|
||||||
@@ -87,29 +82,6 @@ class EmbedChain:
|
|||||||
self.user_asks.append([data_type, url])
|
self.user_asks.append([data_type, url])
|
||||||
self.load_and_embed(loader, chunker, url)
|
self.load_and_embed(loader, chunker, url)
|
||||||
|
|
||||||
def _get_or_create_db(self):
|
|
||||||
"""
|
|
||||||
Returns a ChromaDB client, creates a new one if needed.
|
|
||||||
|
|
||||||
:return: The ChromaDB client.
|
|
||||||
"""
|
|
||||||
client_settings = chromadb.config.Settings(
|
|
||||||
chroma_db_impl="duckdb+parquet",
|
|
||||||
persist_directory=DB_DIR,
|
|
||||||
anonymized_telemetry=False
|
|
||||||
)
|
|
||||||
return chromadb.Client(client_settings)
|
|
||||||
|
|
||||||
def _get_or_create_collection(self):
|
|
||||||
"""
|
|
||||||
Returns a ChromaDB collection, creates a new one if needed.
|
|
||||||
|
|
||||||
:return: The ChromaDB collection.
|
|
||||||
"""
|
|
||||||
return self.chromadb_client.get_or_create_collection(
|
|
||||||
'embedchain_store', embedding_function=openai_ef,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_and_embed(self, loader, chunker, url):
|
def load_and_embed(self, loader, chunker, url):
|
||||||
"""
|
"""
|
||||||
Loads the data from the given URL, chunks it, and adds it to the database.
|
Loads the data from the given URL, chunks it, and adds it to the database.
|
||||||
|
|||||||
10
embedchain/vectordb/base_vector_db.py
Normal file
10
embedchain/vectordb/base_vector_db.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
class BaseVectorDB:
    """
    Base class for vector database backends.

    Subclasses provide the backend-specific pieces by overriding
    ``_get_or_create_db`` and ``_get_or_create_collection``. The constructor
    calls both hooks and stores their results on ``self.client`` and
    ``self.collection``.
    """

    def __init__(self):
        """
        Creates the database client and then the collection.

        The client is assigned before the collection hook runs so that the
        hook can read ``self.client``.
        """
        self.client = self._get_or_create_db()
        self.collection = self._get_or_create_collection()

    def _get_or_create_db(self):
        """
        Returns the database client. Must be overridden by subclasses.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_or_create_db()"
        )

    def _get_or_create_collection(self):
        """
        Returns the collection. Must be overridden by subclasses.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_or_create_collection()"
        )
|
||||||
26
embedchain/vectordb/chroma_db.py
Normal file
26
embedchain/vectordb/chroma_db.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import os
|
||||||
|
import chromadb
|
||||||
|
from base_vector_db import BaseVectorDB
|
||||||
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
# Embedding function passed to the Chroma collection created in this module.
# It is built once at import time, so OPENAI_API_KEY is read from the
# environment when this module is first imported.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    model_name="text-embedding-ada-002",
    api_key=os.getenv("OPENAI_API_KEY"),
)
|
||||||
|
|
||||||
|
class ChromaDB(BaseVectorDB):
    """
    Vector database backend that uses a Chroma client persisting to a
    local directory.
    """

    def __init__(self, db_dir=None):
        """
        Builds the Chroma client settings, then lets the base class create
        the client and the collection.

        :param db_dir: Directory in which Chroma persists its data. When
        omitted, a ``db`` folder under the current working directory is used.
        """
        if db_dir is None:
            db_dir = os.path.join(os.getcwd(), "db")
        # The settings must exist before super().__init__() runs, because
        # the base constructor calls _get_or_create_db(), which reads them.
        self.client_settings = chromadb.config.Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=db_dir,
            anonymized_telemetry=False,
        )
        super().__init__()

    def _get_or_create_db(self):
        """
        Returns a Chroma client built from ``self.client_settings``.

        :return: The Chroma client.
        """
        return chromadb.Client(self.client_settings)

    def _get_or_create_collection(self):
        """
        Returns the ``embedchain_store`` collection, creating it if needed.

        :return: The Chroma collection.
        """
        return self.client.get_or_create_collection(
            'embedchain_store', embedding_function=openai_ef,
        )
|
||||||
Reference in New Issue
Block a user