Open source embedding and LLM models (#133)
* Add open source LLM model: gpt4all
* Add open source embedding model: sentence transformers
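
With this change, OpenSourceApp becomes a drop-in alternative to App that needs no OpenAI key. A minimal usage sketch (the data source and query are illustrative):

    from embedchain import App, OpenSourceApp

    # Hosted stack: OpenAI embeddings (text-embedding-ada-002) + gpt-3.5-turbo.
    app = App()

    # Local stack: sentence-transformers embeddings + a GPT4All model.
    open_source_app = OpenSourceApp()
    open_source_app.add("web_page", "https://example.com/post")
    print(open_source_app.query("What is the post about?"))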

embedchain/__init__.py:

@@ -1 +1 @@
-from .embedchain import App
+from .embedchain import App, OpenSourceApp

embedchain/embedchain.py:

@@ -1,7 +1,9 @@
 import openai
 import os
 
+from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
+from gpt4all import GPT4All
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 
@@ -17,16 +19,23 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.vectordb.chroma_db import ChromaDB
 
+load_dotenv()
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    organization_id=os.getenv("OPENAI_ORGANIZATION"),
+    model_name="text-embedding-ada-002"
+)
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
+
 embeddings = OpenAIEmbeddings()
+gpt4all_model = None
 
-load_dotenv()
-
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
 
 class EmbedChain:
-    def __init__(self, db=None):
+    def __init__(self, db=None, ef=None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
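
Both module-level embedding functions follow Chroma's EmbeddingFunction protocol: a callable that takes a list of texts and returns a list of embedding vectors. A minimal sketch of that contract, assuming a chromadb version of this era (all-MiniLM-L6-v2 produces 384-dimensional vectors):

    from chromadb.utils import embedding_functions

    # Chroma embedding functions are callables: list of texts -> list of vectors.
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-MiniLM-L6-v2"
    )
    vectors = ef(["hello world"])
    print(len(vectors), len(vectors[0]))  # 1 vector, 384 dimensions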

@@ -34,7 +43,7 @@ class EmbedChain:
         :param db: The instance of the VectorDB subclass.
         """
         if db is None:
-            db = ChromaDB()
+            db = ChromaDB(ef=ef)
         self.db_client = db.client
         self.collection = db.collection
         self.user_asks = []
@@ -154,20 +163,9 @@ class EmbedChain:
             )
         ]
 
-    def get_openai_answer(self, prompt):
-        messages = []
-        messages.append({
-            "role": "user", "content": prompt
-        })
-        response = openai.ChatCompletion.create(
-            model="gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=0,
-            max_tokens=1000,
-            top_p=1,
-        )
-        return response["choices"][0]["message"]["content"]
+    def get_llm_model_answer(self, prompt):
+        raise NotImplementedError
 
     def retrieve_from_database(self, input_query):
         """
         Queries the vector database based on the given input query.
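
The refactor turns the answer step into a template-method hook: EmbedChain keeps the retrieval and prompt-building logic, and each subclass supplies the LLM call by overriding get_llm_model_answer. A sketch of how a third backend could plug in (EchoApp and its trivial "LLM" are purely illustrative):

    from embedchain.embedchain import EmbedChain

    class EchoApp(EmbedChain):
        # Illustrative subclass: any callable mapping prompt -> answer text
        # can back the hook; real subclasses call an actual model here.
        def get_llm_model_answer(self, prompt):
            return "echo: " + prompt[:80]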

@@ -186,7 +184,7 @@ class EmbedChain:
         else:
             content = ""
         return content
-
+
     def generate_prompt(self, input_query, context):
         """
         Generates a prompt based on the given query and context, ready to be passed to an LLM
@@ -211,7 +209,7 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_openai_answer(prompt)
+        answer = self.get_llm_model_answer(prompt)
         return answer
 
     def query(self, input_query):
@@ -237,4 +235,50 @@ class App(EmbedChain):
     adds(data_type, url): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
     """
-    pass
+
+    def __init__(self, db=None, ef=None):
+        if ef is None:
+            ef = openai_ef
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        messages = []
+        messages.append({
+            "role": "user", "content": prompt
+        })
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-0613",
+            messages=messages,
+            temperature=0,
+            max_tokens=1000,
+            top_p=1,
+        )
+        return response["choices"][0]["message"]["content"]
+
+
+class OpenSourceApp(EmbedChain):
+    """
+    The OpenSource app.
+    Same as App, but uses an open source embedding model and LLM.
+
+    Has two functions: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    """
+
+    def __init__(self, db=None, ef=None):
+        print("Loading open source embedding model. This may take some time...")
+        if ef is None:
+            ef = sentence_transformer_ef
+        print("Successfully loaded open source embedding model.")
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        global gpt4all_model
+        if gpt4all_model is None:
+            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
+        response = gpt4all_model.generate(
+            prompt=prompt,
+        )
+        return response
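
Note the module-level gpt4all_model cache: the GPT4All weights are loaded lazily on the first query and reused afterwards, so only the first call pays the model-loading cost. The same lazy-initialization pattern in isolation (the helper name is illustrative; the model file is the one used above):

    from gpt4all import GPT4All

    _model = None  # load once at first use, reuse across calls

    def local_answer(prompt):
        global _model
        if _model is None:
            # First call loads (and, if needed, downloads) the weights.
            _model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
        return _model.generate(prompt=prompt)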

embedchain/vectordb/chroma_db.py:

@@ -12,7 +12,8 @@ openai_ef = embedding_functions.OpenAIEmbeddingFunction(
 )
 
 class ChromaDB(BaseVectorDB):
-    def __init__(self, db_dir=None):
+    def __init__(self, db_dir=None, ef=None):
+        self.ef = ef if ef is not None else openai_ef
         if db_dir is None:
             db_dir = "db"
         self.client_settings = chromadb.config.Settings(
@@ -27,5 +28,5 @@ class ChromaDB(BaseVectorDB):
 
     def _get_or_create_collection(self):
         return self.client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
+            'embedchain_store', embedding_function=self.ef,
         )
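
End to end, the embedding function is now threaded through all three layers: the app picks a default (openai_ef or sentence_transformer_ef), EmbedChain.__init__ forwards it to ChromaDB, and the collection stores it so that documents and queries are embedded with the same model. Roughly equivalent wiring with a raw Chroma client (a sketch; the Settings fields assume the 0.3-era chromadb API that the surrounding code targets):

    import chromadb
    from chromadb.config import Settings
    from chromadb.utils import embedding_functions

    client = chromadb.Client(Settings(
        chroma_db_impl="duckdb+parquet",  # assumed persistence backend
        persist_directory="db",
    ))
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-MiniLM-L6-v2"
    )
    # The collection remembers its embedding function, so add() and
    # query() embed texts consistently.
    collection = client.get_or_create_collection(
        "embedchain_store", embedding_function=ef
    )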