Open source embedding and LLM models (#133)

* Add open source LLM model: gpt4all
* Add open source embedding model: sentence transformers
This commit is contained in:
Taranjeet Singh
2023-07-05 02:23:23 +05:30
committed by GitHub
parent 3461ef4b14
commit cf1e000fb3
4 changed files with 71 additions and 24 deletions

View File

@@ -1 +1 @@
from .embedchain import App
from .embedchain import App, OpenSourceApp

View File

@@ -1,7 +1,9 @@
import openai
import os
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from gpt4all import GPT4All
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
@@ -17,16 +19,23 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.text import TextChunker
from embedchain.vectordb.chroma_db import ChromaDB
load_dotenv()
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
organization_id=os.getenv("OPENAI_ORGANIZATION"),
model_name="text-embedding-ada-002"
)
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
embeddings = OpenAIEmbeddings()
gpt4all_model = None
load_dotenv()
ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db")
class EmbedChain:
def __init__(self, db=None):
def __init__(self, db=None, ef=None):
"""
Initializes the EmbedChain instance, sets up a vector DB client and
creates a collection.
@@ -34,7 +43,7 @@ class EmbedChain:
:param db: The instance of the VectorDB subclass.
"""
if db is None:
db = ChromaDB()
db = ChromaDB(ef=ef)
self.db_client = db.client
self.collection = db.collection
self.user_asks = []
@@ -154,20 +163,9 @@ class EmbedChain:
)
]
def get_openai_answer(self, prompt):
    """
    Send *prompt* to the OpenAI chat completion API and return the answer.

    :param prompt: The fully formed prompt string to send to the model.
    :return: The text content of the first completion choice.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        messages=chat_messages,
        temperature=0,
        max_tokens=1000,
        top_p=1,
    )
    return completion["choices"][0]["message"]["content"]
def get_llm_model_answer(self, prompt):
    """
    Return the LLM's answer for *prompt*.

    Abstract hook: concrete subclasses (e.g. ``App``, ``OpenSourceApp``)
    must override this with an actual model call.

    :param prompt: The fully formed prompt string to send to the model.
    :raises NotImplementedError: always, on this base class.
    """
    raise NotImplementedError
def retrieve_from_database(self, input_query):
"""
Queries the vector database based on the given input query.
@@ -186,7 +184,7 @@ class EmbedChain:
else:
content = ""
return content
def generate_prompt(self, input_query, context):
"""
Generates a prompt based on the given query and context, ready to be passed to an LLM
@@ -211,7 +209,7 @@ class EmbedChain:
:param context: Similar documents to the query used as context.
:return: The answer.
"""
answer = self.get_openai_answer(prompt)
answer = self.get_llm_model_answer(prompt)
return answer
def query(self, input_query):
@@ -237,4 +235,50 @@ class App(EmbedChain):
adds(data_type, url): adds the data from the given URL to the vector db.
query(query): finds answer to the given query using vector database and LLM.
"""
pass
def __init__(self, db=None, ef=None):
    """
    Initialize the App with the OpenAI embedding function by default.

    Bug fix: the method was misspelled ``__int__``, so Python never called
    it as the constructor and the ``ef`` default below was silently skipped;
    renamed to ``__init__``.

    :param db: Optional VectorDB instance; the base class creates a
        ChromaDB when None.
    :param ef: Optional embedding function; defaults to the module-level
        OpenAI embedding function when None.
    """
    if ef is None:
        ef = openai_ef
    super().__init__(db, ef)
def get_llm_model_answer(self, prompt):
    """
    Answer *prompt* using the OpenAI chat completion API.

    :param prompt: The fully formed prompt string to send to the model.
    :return: The text content of the first completion choice.
    """
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=1000,
        top_p=1,
    )
    return response["choices"][0]["message"]["content"]
class OpenSourceApp(EmbedChain):
    """
    The OpenSource app.

    Behaves like ``App`` but relies on locally runnable open source models:
    a sentence-transformers embedding function and a GPT4All LLM, instead
    of the OpenAI APIs.

    Public surface:
        add(data_type, url): adds the data from the given URL to the vector db.
        query(query): finds answer to the given query using vector database and LLM.
    """

    def __init__(self, db=None, ef=None):
        """
        Set up the vector database with an open source embedding function.

        :param db: Optional VectorDB instance; the base class creates a
            ChromaDB when None.
        :param ef: Optional embedding function; defaults to the
            sentence-transformers MiniLM function when None.
        """
        print("Loading open source embedding model. This may take some time...")
        ef = sentence_transformer_ef if ef is None else ef
        print("Successfully loaded open source embedding model.")
        super().__init__(db, ef)

    def get_llm_model_answer(self, prompt):
        """
        Answer *prompt* with the GPT4All model, loading it lazily.

        The model is cached in the module-level ``gpt4all_model`` so the
        slow weight load happens at most once per process.

        :param prompt: The fully formed prompt string to send to the model.
        :return: The raw text generated by the model.
        """
        global gpt4all_model
        if gpt4all_model is None:
            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
        return gpt4all_model.generate(
            prompt=prompt,
        )

View File

@@ -12,7 +12,8 @@ openai_ef = embedding_functions.OpenAIEmbeddingFunction(
)
class ChromaDB(BaseVectorDB):
def __init__(self, db_dir=None):
def __init__(self, db_dir=None, ef=None):
self.ef = ef if ef is not None else openai_ef
if db_dir is None:
db_dir = "db"
self.client_settings = chromadb.config.Settings(
@@ -27,5 +28,5 @@ class ChromaDB(BaseVectorDB):
def _get_or_create_collection(self):
return self.client.get_or_create_collection(
'embedchain_store', embedding_function=openai_ef,
'embedchain_store', embedding_function=self.ef,
)