Open source embedding and LLM models (#133)
* Add open source LLM model: gpt4all
* Add open source embedding model: sentence transformers
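For context, a minimal usage sketch of the new `OpenSourceApp` introduced below. Only the `add`/`query` call shape comes from the docstrings in this diff; the `"web_page"` data type string and the URL are placeholder assumptions:

```python
from embedchain import OpenSourceApp

# Uses the all-MiniLM-L6-v2 sentence-transformers embedding function and the
# orca-mini-3b gpt4all model from this commit; no OPENAI_API_KEY required.
app = OpenSourceApp()

# Hypothetical data type and URL, following the adds(data_type, url) docstring.
app.add("web_page", "https://example.com/article")

# Embeds the query, retrieves similar chunks from Chroma, answers with gpt4all.
print(app.query("What is the article about?"))
```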
--- a/embedchain/__init__.py
+++ b/embedchain/__init__.py
@@ -1 +1 @@
-from .embedchain import App
+from .embedchain import App, OpenSourceApp
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -1,7 +1,9 @@
 import openai
 import os
 
+from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
+from gpt4all import GPT4All
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 
@@ -17,16 +19,23 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
 from embedchain.vectordb.chroma_db import ChromaDB
 
-load_dotenv()
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    organization_id=os.getenv("OPENAI_ORGANIZATION"),
+    model_name="text-embedding-ada-002"
+)
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
 
-embeddings = OpenAIEmbeddings()
+gpt4all_model = None
 
+load_dotenv()
+
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
 
 class EmbedChain:
-    def __init__(self, db=None):
+    def __init__(self, db=None, ef=None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
@@ -34,7 +43,7 @@ class EmbedChain:
         :param db: The instance of the VectorDB subclass.
         """
         if db is None:
-            db = ChromaDB()
+            db = ChromaDB(ef=ef)
         self.db_client = db.client
         self.collection = db.collection
         self.user_asks = []
@@ -154,20 +163,9 @@ class EmbedChain:
             )
         ]
 
-    def get_openai_answer(self, prompt):
-        messages = []
-        messages.append({
-            "role": "user", "content": prompt
-        })
-        response = openai.ChatCompletion.create(
-            model="gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=0,
-            max_tokens=1000,
-            top_p=1,
-        )
-        return response["choices"][0]["message"]["content"]
+    def get_llm_model_answer(self, prompt):
+        raise NotImplementedError
 
     def retrieve_from_database(self, input_query):
         """
         Queries the vector database based on the given input query.
@@ -186,7 +184,7 @@ class EmbedChain:
         else:
             content = ""
         return content
 
     def generate_prompt(self, input_query, context):
         """
         Generates a prompt based on the given query and context, ready to be passed to an LLM
@@ -211,7 +209,7 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_openai_answer(prompt)
+        answer = self.get_llm_model_answer(prompt)
         return answer
 
     def query(self, input_query):
@@ -237,4 +235,50 @@ class App(EmbedChain):
     adds(data_type, url): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
     """
-    pass
+
+    def __init__(self, db=None, ef=None):
+        if ef is None:
+            ef = openai_ef
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        messages = []
+        messages.append({
+            "role": "user", "content": prompt
+        })
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo-0613",
+            messages=messages,
+            temperature=0,
+            max_tokens=1000,
+            top_p=1,
+        )
+        return response["choices"][0]["message"]["content"]
+
+
+class OpenSourceApp(EmbedChain):
+    """
+    The OpenSource app.
+    Same as App, but uses an open source embedding model and LLM.
+
+    Has two functions: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    """
+
+    def __init__(self, db=None, ef=None):
+        print("Loading open source embedding model. This may take some time...")
+        if ef is None:
+            ef = sentence_transformer_ef
+        print("Successfully loaded open source embedding model.")
+        super().__init__(db, ef)
+
+    def get_llm_model_answer(self, prompt):
+        global gpt4all_model
+        if gpt4all_model is None:
+            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
+        response = gpt4all_model.generate(
+            prompt=prompt,
+        )
+        return response
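Note on the design: `EmbedChain.get_llm_model_answer` now raises `NotImplementedError`, so the LLM backend is chosen entirely by the subclass (`App` keeps OpenAI, `OpenSourceApp` uses gpt4all). A hypothetical sketch of a third backend, assuming only the subclassing contract above (`CustomLLMApp` and `llm_fn` are illustrative names, not part of this commit):

```python
from embedchain.embedchain import EmbedChain


class CustomLLMApp(EmbedChain):
    """Hypothetical app that delegates answering to any prompt -> str callable."""

    def __init__(self, llm_fn, db=None, ef=None):
        self.llm_fn = llm_fn
        super().__init__(db, ef)

    def get_llm_model_answer(self, prompt):
        # Satisfies the contract the base class leaves unimplemented.
        return self.llm_fn(prompt)
```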
--- a/embedchain/vectordb/chroma_db.py
+++ b/embedchain/vectordb/chroma_db.py
@@ -12,7 +12,8 @@ openai_ef = embedding_functions.OpenAIEmbeddingFunction(
 )
 
 class ChromaDB(BaseVectorDB):
-    def __init__(self, db_dir=None):
+    def __init__(self, db_dir=None, ef=None):
+        self.ef = ef if ef is not None else openai_ef
         if db_dir is None:
             db_dir = "db"
         self.client_settings = chromadb.config.Settings(
@@ -27,5 +28,5 @@ class ChromaDB(BaseVectorDB):
 
     def _get_or_create_collection(self):
         return self.client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
+            'embedchain_store', embedding_function=self.ef,
         )
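With `ef` now threaded through `ChromaDB`, the collection's embedding function is injectable end to end. A sketch, assuming any chromadb `EmbeddingFunction` is accepted here (the alternative model name is an arbitrary example):

```python
from chromadb.utils import embedding_functions

from embedchain import App

# Any chromadb EmbeddingFunction can be supplied; this model choice is illustrative.
custom_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="paraphrase-MiniLM-L3-v2"
)

# EmbedChain.__init__ passes ef through to ChromaDB(ef=ef), which hands it
# to get_or_create_collection as embedding_function.
app = App(db=None, ef=custom_ef)
```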