[Improvement] Add support for gpt4all through langchain (#838)

Author: Deven Patel
Date: 2023-10-25 22:25:00 -07:00
Committed by: GitHub
Parent: d77e8da3f3
Commit: 0f8a2e624a

8 changed files with 35 additions and 28 deletions

embedchain/embedder/gpt4all.py

@@ -1,7 +1,5 @@
 from typing import Optional
 
-from chromadb.utils import embedding_functions
-
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
@@ -9,12 +7,13 @@ from embedchain.models import VectorDimensions
 class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
-        # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
         super().__init__(config=config)
 
-        if self.config.model is None:
-            self.config.model = "all-MiniLM-L6-v2"
-
-        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
+        from langchain.embeddings import \
+            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+
+        embeddings = LangchainGPT4AllEmbeddings()
+        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 
         vector_dimension = VectorDimensions.GPT4ALL.value
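
Aside on the embedder diff above: the Chroma SentenceTransformer embedding function is swapped for a LangChain-backed one, adapted through BaseEmbedder._langchain_default_concept. That helper is not part of this commit; the snippet below is only a sketch of the adapter it presumably implements (the name, signature, and embed_function wrapper are assumptions for illustration, not the actual embedchain source):

    from typing import Callable, List

    # Hypothetical sketch: wrap a LangChain embeddings object in the plain
    # list-of-texts callable that a vector store's embedding hook expects.
    def _langchain_default_concept(embeddings) -> Callable[[List[str]], List[List[float]]]:
        def embed_function(texts: List[str]) -> List[List[float]]:
            # LangChain embeddings expose embed_documents(texts) -> list of vectors
            return embeddings.embed_documents(texts)
        return embed_function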

embedchain/llm/gpt4all.py

@@ -1,5 +1,8 @@
 from typing import Iterable, Optional, Union
 
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.stdout import StdOutCallbackHandler
+
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -12,6 +15,7 @@ class GPT4ALLLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
         self.instance = GPT4ALLLlm._get_instance(self.config.model)
+        self.instance.streaming = self.config.stream
 
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
@@ -19,13 +23,13 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from gpt4all import GPT4All
+            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
-        return GPT4All(model_name=model)
+        return LangchainGPT4All(model=model, allow_download=True)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model:
@@ -33,14 +37,25 @@ class GPT4ALLLlm(BaseLlm):
                 "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
             )
 
+        messages = []
         if config.system_prompt:
-            raise ValueError("GPT4ALLLlm does not support `system_prompt`")
+            messages.append(config.system_prompt)
+        messages.append(prompt)
+
+        kwargs = {
+            "temp": config.temperature,
+            "max_tokens": config.max_tokens,
+        }
+        if config.top_p:
+            kwargs["top_p"] = config.top_p
 
-        response = self.instance.generate(
-            prompt=prompt,
-            streaming=config.stream,
-            top_p=config.top_p,
-            max_tokens=config.max_tokens,
-            temp=config.temperature,
-        )
-        return response
+        callbacks = None
+        if config.stream:
+            callbacks = [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
+        response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
+        answer = ""
+        for generations in response.generations:
+            answer += " ".join(map(lambda generation: generation.text, generations))
+        return answer
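
For reference, the rewritten _get_answer drives LangChain's BaseLLM.generate, which takes a list of prompts and returns an LLMResult whose .generations field holds one list of Generation objects per prompt. Below is a standalone sketch mirroring the new code path, assuming langchain is installed and the model file can be downloaded; temp and max_tokens are forwarded as extra keyword arguments exactly as the diff does:

    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms.gpt4all import GPT4All as LangchainGPT4All

    # Same construction as _get_instance above; downloads the model if missing.
    llm = LangchainGPT4All(model="orca-mini-3b.ggmlv3.q4_0.bin", allow_download=True)
    llm.streaming = True  # mirrors `self.instance.streaming = self.config.stream`

    result = llm.generate(
        prompts=["What is GPT4All?"],
        callbacks=[StreamingStdOutCallbackHandler()],  # stream tokens to stdout
        temp=0.7,
        max_tokens=256,
    )
    # .generations[0] lists the Generation candidates for the first prompt.
    answer = " ".join(gen.text for gen in result.generations[0])
    print(answer)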