From 0f8a2e624ad948040751e6567000090e356afcd3 Mon Sep 17 00:00:00 2001
From: Deven Patel
Date: Wed, 25 Oct 2023 22:25:00 -0700
Subject: [PATCH] [Improvement] Add support for gpt4all through langchain
 (#838)

---
 configs/gpt4all.yaml                 |  4 +--
 configs/opensource.yaml              |  3 +--
 docs/components/embedding-models.mdx |  2 --
 docs/components/llms.mdx             |  2 --
 embedchain/embedder/gpt4all.py       | 11 ++++-----
 embedchain/llm/gpt4all.py            | 37 +++++++++++++++++++---------
 pyproject.toml                       |  2 +-
 tests/apps/test_apps.py              |  2 +-
 8 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/configs/gpt4all.yaml b/configs/gpt4all.yaml
index e7bba542..5cc66575 100644
--- a/configs/gpt4all.yaml
+++ b/configs/gpt4all.yaml
@@ -1,7 +1,7 @@
 llm:
   provider: gpt4all
-  model: 'orca-mini-3b.ggmlv3.q4_0.bin'
   config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
     temperature: 0.5
     max_tokens: 1000
     top_p: 1
@@ -9,5 +9,3 @@ llm:
 
 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
diff --git a/configs/opensource.yaml b/configs/opensource.yaml
index 5e8e9c60..44bd3570 100644
--- a/configs/opensource.yaml
+++ b/configs/opensource.yaml
@@ -6,8 +6,8 @@ app:
 
 llm:
   provider: gpt4all
-  model: 'orca-mini-3b.ggmlv3.q4_0.bin'
   config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
     temperature: 0.5
     max_tokens: 1000
     top_p: 1
@@ -23,5 +23,4 @@ vectordb:
 embedder:
   provider: gpt4all
   config:
-    model: 'all-MiniLM-L6-v2'
     deployment_name: null
diff --git a/docs/components/embedding-models.mdx b/docs/components/embedding-models.mdx
index 9f5ab5aa..0f030bdc 100644
--- a/docs/components/embedding-models.mdx
+++ b/docs/components/embedding-models.mdx
@@ -108,8 +108,6 @@ llm:
 
 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
 ```
 
 
diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index 328e05cd..49b5d7cd 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -198,8 +198,6 @@ llm:
 
 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
 ```
 
 
diff --git a/embedchain/embedder/gpt4all.py b/embedchain/embedder/gpt4all.py
index f12b2bcf..d078825d 100644
--- a/embedchain/embedder/gpt4all.py
+++ b/embedchain/embedder/gpt4all.py
@@ -1,7 +1,5 @@
 from typing import Optional
 
-from chromadb.utils import embedding_functions
-
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
@@ -9,12 +7,13 @@ from embedchain.models import VectorDimensions
 
 class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
-        # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
         super().__init__(config=config)
 
-        if self.config.model is None:
-            self.config.model = "all-MiniLM-L6-v2"
-        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
+        from langchain.embeddings import \
+            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+
+        embeddings = LangchainGPT4AllEmbeddings()
+        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 
         vector_dimension = VectorDimensions.GPT4ALL.value
diff --git a/embedchain/llm/gpt4all.py b/embedchain/llm/gpt4all.py
index 2fa68eca..99d0e15f 100644
--- a/embedchain/llm/gpt4all.py
+++ b/embedchain/llm/gpt4all.py
@@ -1,5 +1,8 @@
 from typing import Iterable, Optional, Union
 
+from langchain.callbacks.stdout import StdOutCallbackHandler
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -12,6 +15,7 @@ class GPT4ALLLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
         self.instance = GPT4ALLLlm._get_instance(self.config.model)
+        self.instance.streaming = self.config.stream
 
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
@@ -19,13 +23,13 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from gpt4all import GPT4All
+            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
-        return GPT4All(model_name=model)
+        return LangchainGPT4All(model=model, allow_download=True)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model:
@@ -32,14 +36,25 @@ class GPT4ALLLlm(BaseLlm):
             raise ValueError(
                 "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
             )
+        messages = []
         if config.system_prompt:
-            raise ValueError("GPT4ALLLlm does not support `system_prompt`")
+            messages.append(config.system_prompt)
+        messages.append(prompt)
+        kwargs = {
+            "temp": config.temperature,
+            "max_tokens": config.max_tokens,
+        }
+        if config.top_p:
+            kwargs["top_p"] = config.top_p
 
-        response = self.instance.generate(
-            prompt=prompt,
-            streaming=config.stream,
-            top_p=config.top_p,
-            max_tokens=config.max_tokens,
-            temp=config.temperature,
-        )
-        return response
+        callbacks = None
+        if config.stream:
+            callbacks = [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
+
+        response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
+        answer = ""
+        for generations in response.generations:
+            answer += " ".join(map(lambda generation: generation.text, generations))
+        return answer
diff --git a/pyproject.toml b/pyproject.toml
index fabad3c1..b0de3093 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -143,7 +143,7 @@ pytest-asyncio = "^0.21.1"
 [tool.poetry.extras]
 streamlit = ["streamlit"]
 community = ["llama-hub"]
-opensource = ["sentence-transformers", "torch", "gpt4all"]
+opensource = ["sentence-transformers", "torch", "gpt4all", "langchain"]
 elasticsearch = ["elasticsearch"]
 opensearch = ["opensearch-py"]
 poe = ["fastapi-poe"]
diff --git a/tests/apps/test_apps.py b/tests/apps/test_apps.py
index 1ffe2b20..32632747 100644
--- a/tests/apps/test_apps.py
+++ b/tests/apps/test_apps.py
@@ -135,6 +135,7 @@ class TestAppFromConfig:
 
         # Validate the LLM config values
         llm_config = config_data["llm"]["config"]
+        assert app.llm.config.model == llm_config["model"]
         assert app.llm.config.temperature == llm_config["temperature"]
         assert app.llm.config.max_tokens == llm_config["max_tokens"]
         assert app.llm.config.top_p == llm_config["top_p"]
@@ -148,5 +149,4 @@ class TestAppFromConfig:
 
         # Validate the Embedder config values
         embedder_config = config_data["embedder"]["config"]
-        assert app.embedder.config.model == embedder_config["model"]
         assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
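
Reviewer notes (not part of the patch):

1. Config schema: the GPT4All model name now lives under `llm.config`, and
   the `embedder` section needs only the provider. A minimal usage sketch,
   assuming embedchain's `App.from_config` loader (the entry point the
   updated `TestAppFromConfig` exercises) and the `configs/gpt4all.yaml`
   shipped in this patch:

    # Sketch only: App.from_config and its yaml_path argument are assumed
    # from the surrounding test suite, not defined in this patch.
    from embedchain import App

    app = App.from_config(yaml_path="configs/gpt4all.yaml")
    app.add("https://en.wikipedia.org/wiki/GPT4All")
    print(app.query("What is GPT4All?"))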
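
2. Embeddings: the ChromaDB `SentenceTransformerEmbeddingFunction` is
   replaced by LangChain's `GPT4AllEmbeddings` wrapped through
   `BaseEmbedder._langchain_default_concept`. That adapter is not shown in
   the patch; the sketch below assumes it simply exposes `embed_documents`
   as a Chroma-style callable:

    # Assumption: _langchain_default_concept adapts a LangChain Embeddings
    # object as "list of texts in, list of vectors out".
    from langchain.embeddings import GPT4AllEmbeddings

    embeddings = GPT4AllEmbeddings()  # pulls a small ggml embedding model on first use

    def embedding_fn(texts):
        return embeddings.embed_documents(texts)

    vectors = embedding_fn(["hello world"])
    print(len(vectors[0]))  # should agree with VectorDimensions.GPT4ALL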
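
3. Generation: `_get_answer` now goes through LangChain's batch `generate`,
   which returns an `LLMResult` holding one list of `Generation` objects per
   input prompt; the patch flattens their `.text` into a single string. A
   standalone sketch of that contract (assuming the langchain 0.0.x line
   pulled in by the `opensource` extra):

    # Sketch only: demonstrates the LLMResult shape; sampling parameters are
    # set on the constructor here rather than passed through generate().
    from langchain.llms.gpt4all import GPT4All

    llm = GPT4All(model="orca-mini-3b.ggmlv3.q4_0.bin", allow_download=True, temp=0.5)
    result = llm.generate(prompts=["Name three planets."])

    answer = ""
    for generations in result.generations:
        answer += " ".join(generation.text for generation in generations)
    print(answer)

   Note that a configured `system_prompt` is sent as a second standalone
   prompt rather than as a chat system message, so its completion is folded
   into the returned answer alongside the main prompt's.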