[Improvement] Add support for gpt4all through langchain (#838)
@@ -1,7 +1,7 @@
 llm:
   provider: gpt4all
-  model: 'orca-mini-3b.ggmlv3.q4_0.bin'
   config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
     temperature: 0.5
     max_tokens: 1000
     top_p: 1
@@ -9,5 +9,3 @@ llm:

 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
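After this change the LLM model moves under `config:` and the embedder model key disappears (the embedder hunks below show it is no longer read). A minimal usage sketch for the resulting YAML, assuming the `App.from_config(yaml_path=...)` API that the TestAppFromConfig tests near the end of this diff exercise; the file name and source URL are illustrative, not part of the commit:

from embedchain import App

# Load the gpt4all provider config shown in the hunks above (hypothetical path).
app = App.from_config(yaml_path="gpt4all.yaml")
app.add("https://example.com/article")  # ingest a source
print(app.query("Summarize the article."))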
@@ -6,8 +6,8 @@ app:

 llm:
   provider: gpt4all
-  model: 'orca-mini-3b.ggmlv3.q4_0.bin'
   config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
     temperature: 0.5
     max_tokens: 1000
     top_p: 1
@@ -23,5 +23,4 @@ vectordb:
 embedder:
   provider: gpt4all
   config:
-    model: 'all-MiniLM-L6-v2'
     deployment_name: null
@@ -108,8 +108,6 @@ llm:

 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
 ```

 </CodeGroup>
@@ -198,8 +198,6 @@ llm:

 embedder:
   provider: gpt4all
-  config:
-    model: 'all-MiniLM-L6-v2'
 ```
 </CodeGroup>

@@ -1,7 +1,5 @@
 from typing import Optional

-from chromadb.utils import embedding_functions
-
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
@@ -9,12 +7,13 @@ from embedchain.models import VectorDimensions

 class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
-        # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
         super().__init__(config=config)
-        if self.config.model is None:
-            self.config.model = "all-MiniLM-L6-v2"

-        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
+        from langchain.embeddings import \
+            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+
+        embeddings = LangchainGPT4AllEmbeddings()
+        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)

         vector_dimension = VectorDimensions.GPT4ALL.value
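The helper `BaseEmbedder._langchain_default_concept` is not shown in this diff. A plausible sketch of what it does, assuming the Chroma convention that an embedding function maps a list of texts to a list of vectors; the signature and body here are inferred, not taken from the commit:

from typing import Callable, List

def _langchain_default_concept(embeddings) -> Callable[[List[str]], List[List[float]]]:
    """Adapt a LangChain embedder to a Chroma-style embedding function (inferred)."""
    def embed_function(texts: List[str]) -> List[List[float]]:
        # LangChain embedders expose embed_documents(texts) -> list of vectors.
        return embeddings.embed_documents(texts)
    return embed_function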
@@ -1,5 +1,8 @@
 from typing import Iterable, Optional, Union

+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.stdout import StdOutCallbackHandler
+
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -12,6 +15,7 @@ class GPT4ALLLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
         self.instance = GPT4ALLLlm._get_instance(self.config.model)
+        self.instance.streaming = self.config.stream

     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
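Setting `streaming` on the wrapper is what lets callback handlers receive tokens as they are generated. A standalone sketch of the same wiring outside embedchain; the model name and prompt are illustrative:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.gpt4all import GPT4All as LangchainGPT4All

llm = LangchainGPT4All(model="orca-mini-3b.ggmlv3.q4_0.bin", allow_download=True)
llm.streaming = True  # mirror what __init__ does above
llm.generate(prompts=["Name three colors."], callbacks=[StreamingStdOutCallbackHandler()])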
@@ -19,13 +23,13 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from gpt4all import GPT4All
+            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None

-        return GPT4All(model_name=model)
+        return LangchainGPT4All(model=model, allow_download=True)

     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model:
@@ -33,14 +37,25 @@ class GPT4ALLLlm(BaseLlm):
                 "GPT4ALLLlm does not support switching models at runtime. Please create a new app instance."
             )

+        messages = []
         if config.system_prompt:
-            raise ValueError("GPT4ALLLlm does not support `system_prompt`")
+            messages.append(config.system_prompt)
+        messages.append(prompt)
+        kwargs = {
+            "temp": config.temperature,
+            "max_tokens": config.max_tokens,
+        }
+        if config.top_p:
+            kwargs["top_p"] = config.top_p

-        response = self.instance.generate(
-            prompt=prompt,
-            streaming=config.stream,
-            top_p=config.top_p,
-            max_tokens=config.max_tokens,
-            temp=config.temperature,
-        )
-        return response
+        callbacks = None
+        if config.stream:
+            callbacks = [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
+
+        response = self.instance.generate(prompts=messages, callbacks=callbacks, **kwargs)
+        answer = ""
+        for generations in response.generations:
+            answer += " ".join(map(lambda generation: generation.text, generations))
+        return answer
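LangChain's `generate` returns an `LLMResult` whose `generations` field is a list with one entry per input prompt, each entry itself a list of `Generation` objects; the loop above flattens that nesting into one string. A self-contained illustration of the shape (constructed by hand here, not output from a real model):

from langchain.schema import Generation, LLMResult

result = LLMResult(generations=[[Generation(text="Hello"), Generation(text="world")]])
answer = ""
for generations in result.generations:
    answer += " ".join(map(lambda generation: generation.text, generations))
print(answer)  # -> "Hello world"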
@@ -143,7 +143,7 @@ pytest-asyncio = "^0.21.1"
 [tool.poetry.extras]
 streamlit = ["streamlit"]
 community = ["llama-hub"]
-opensource = ["sentence-transformers", "torch", "gpt4all"]
+opensource = ["sentence-transformers", "torch", "gpt4all", "langchain"]
 elasticsearch = ["elasticsearch"]
 opensearch = ["opensearch-py"]
 poe = ["fastapi-poe"]
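Since `_get_instance` now imports from langchain, the `opensource` extra declares it explicitly. A quick import check after running the install command quoted in the diff's error message (`pip install --upgrade embedchain[opensource]`); illustrative only:

# Verify both packages the extra now guarantees are importable.
from gpt4all import GPT4All  # raw bindings, still listed in the extra
from langchain.llms.gpt4all import GPT4All as LangchainGPT4All

print("gpt4all and langchain import OK")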
@@ -135,6 +135,7 @@ class TestAppFromConfig:

         # Validate the LLM config values
         llm_config = config_data["llm"]["config"]
+        assert app.llm.config.model == llm_config["model"]
         assert app.llm.config.temperature == llm_config["temperature"]
         assert app.llm.config.max_tokens == llm_config["max_tokens"]
         assert app.llm.config.top_p == llm_config["top_p"]
@@ -148,5 +149,4 @@ class TestAppFromConfig:

         # Validate the Embedder config values
         embedder_config = config_data["embedder"]["config"]
-        assert app.embedder.config.model == embedder_config["model"]
         assert app.embedder.config.deployment_name == embedder_config["deployment_name"]