[Feature] Add support for vllm as llm source (#1149)

Deshraj Yadav
2024-01-09 17:38:53 +05:30
committed by GitHub
parent 5f653e69ae
commit 0373fa231c
9 changed files with 111 additions and 15 deletions

configs/vllm.yaml (new file)

@@ -0,0 +1,14 @@
llm:
  provider: vllm
  config:
    model: 'meta-llama/Llama-2-70b-hf'
    temperature: 0.5
    top_p: 1
    top_k: 10
    stream: true
    trust_remote_code: true
embedder:
  provider: huggingface
  config:
    model: 'BAAI/bge-small-en-v1.5'
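For quick reference, a minimal sketch of running an app off this config from the repo root (the source URL and question are purely illustrative):

```python
from embedchain import App

# Load the vLLM-backed app from the shipped config file
app = App.from_config(config_path="configs/vllm.yaml")

# Illustrative usage: ingest a page, then query against it
app.add("https://docs.vllm.ai/en/latest/")
print(app.query("How do I install vLLM?"))
```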


@@ -14,6 +14,7 @@ Embedchain comes with built-in support for various popular large language models
<Card title="Cohere" href="#cohere"></Card> <Card title="Cohere" href="#cohere"></Card>
<Card title="Together" href="#together"></Card> <Card title="Together" href="#together"></Card>
<Card title="Ollama" href="#ollama"></Card> <Card title="Ollama" href="#ollama"></Card>
<Card title="vLLM" href="#vllm"></Card>
<Card title="GPT4All" href="#gpt4all"></Card> <Card title="GPT4All" href="#gpt4all"></Card>
<Card title="JinaChat" href="#jinachat"></Card> <Card title="JinaChat" href="#jinachat"></Card>
<Card title="Hugging Face" href="#hugging-face"></Card> <Card title="Hugging Face" href="#hugging-face"></Card>
@@ -393,6 +394,34 @@ llm:
</CodeGroup>

## vLLM

Set up vLLM by following the instructions in [their docs](https://docs.vllm.ai/en/latest/getting_started/installation.html).
<CodeGroup>
```python main.py
import os
from embedchain import App
# load llm configuration from config.yaml file
app = App.from_config(config_path="config.yaml")
```
```yaml config.yaml
llm:
  provider: vllm
  config:
    model: 'meta-llama/Llama-2-70b-hf'
    temperature: 0.5
    top_p: 1
    top_k: 10
    stream: true
    trust_remote_code: true
```
</CodeGroup>
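Only `model`, `temperature`, and `top_p` are mapped explicitly by the new `VLLM` class (see `embedchain/llm/vllm.py` below); anything else can ride along in `model_kwargs`, which is merged into the arguments of langchain's `VLLM` wrapper. A hedged sketch using an in-code config dict, assuming `from_config(config=...)` and wrapper parameters such as `trust_remote_code` and `tensor_parallel_size`:

```python
from embedchain import App

# Equivalent in-code config; extra engine arguments ride along in model_kwargs
# (names like tensor_parallel_size are passed straight to langchain's VLLM wrapper)
config = {
    "llm": {
        "provider": "vllm",
        "config": {
            "model": "meta-llama/Llama-2-70b-hf",
            "temperature": 0.5,
            "top_p": 1,
            "model_kwargs": {
                "trust_remote_code": True,   # assumed supported by the wrapper
                "tensor_parallel_size": 2,   # illustrative multi-GPU setting
            },
        },
    },
}
app = App.from_config(config=config)
```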
## GPT4All

Install related dependencies using the following command:
@@ -515,7 +544,7 @@ app = App.from_config(config_path="config.yaml")
```yaml config.yaml
llm:
  provider: huggingface
  config:
    endpoint: https://api-inference.huggingface.co/models/gpt2 # replace with your personal endpoint
```
@@ -525,7 +554,7 @@ If your endpoint requires additional parameters, you can pass them in the `model_kwargs`
```
llm:
  provider: huggingface
  config:
    endpoint: <YOUR_ENDPOINT_URL_HERE>
    model_kwargs:


@@ -9,14 +9,9 @@ from typing import Any, Dict, Optional
import requests
import yaml
-from embedchain.cache import (
-    Config,
-    ExactMatchEvaluation,
-    SearchDistanceEvaluation,
-    cache,
-    gptcache_data_manager,
-    gptcache_pre_function,
-)
+from embedchain.cache import (Config, ExactMatchEvaluation,
+                              SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH


@@ -73,7 +73,7 @@ class BaseLlmConfig(BaseConfig):
        callbacks: Optional[List] = None,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
-       model_kwargs: Optional[Dict[str, Any]] = {},
+       model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes a configuration class instance for the LLM.
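The `{}` → `None` change above fixes Python's shared-mutable-default pitfall: a default `{}` is created once at function definition time and shared across every call. A minimal sketch of the failure mode (the `Config` classes here are hypothetical, for illustration only):

```python
class Config:
    def __init__(self, model_kwargs: dict = {}):  # BUG: one dict shared by all instances
        self.model_kwargs = model_kwargs

a, b = Config(), Config()
a.model_kwargs["top_k"] = 10
print(b.model_kwargs)  # {'top_k': 10} — b was silently mutated too

class SafeConfig:
    def __init__(self, model_kwargs: dict = None):
        # Create a fresh dict per instance instead of sharing one default
        self.model_kwargs = model_kwargs if model_kwargs is not None else {}
```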
@@ -115,6 +115,8 @@ class BaseLlmConfig(BaseConfig):
        :type model_kwargs: Optional[Dict[str, Any]], optional
        :param callbacks: Langchain callback functions to use, defaults to None
        :type callbacks: Optional[List], optional
        :param query_type: The type of query to use, defaults to None
        :type query_type: Optional[str], optional
        :raises ValueError: If the template is not valid as template should
            contain $context and $query (and optionally $history)
        :raises ValueError: Stream is not boolean
@@ -142,6 +144,7 @@ class BaseLlmConfig(BaseConfig):
        self.api_key = api_key
        self.endpoint = endpoint
        self.model_kwargs = model_kwargs
        if type(prompt) is str:
            prompt = Template(prompt)
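The `Template(prompt)` conversion above means prompts are plain `string.Template` objects whose `$context` and `$query` placeholders (the ones the docstring requires) get substituted at query time. A quick illustration:

```python
from string import Template

# Prompts carry $context and $query slots, filled in when a query runs
prompt = Template("Answer from the context.\nContext: $context\nQuery: $query")
print(prompt.substitute(context="vLLM is a fast LLM serving engine.", query="What is vLLM?"))
```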


@@ -7,7 +7,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
-from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
+from embedchain.cache import (
+    adapt,
+    get_gptcache_session,
+    gptcache_data_convert,
+    gptcache_update_cache_callback,
+)
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig


@@ -4,7 +4,9 @@ from typing import Any, Dict, Generator, List, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
+from embedchain.config.llm.base import (DEFAULT_PROMPT,
+                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
+                                        DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage

embedchain/llm/vllm.py (new file)

@@ -0,0 +1,40 @@
from typing import Iterable, Optional, Union

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import VLLM as BaseVLLM

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class VLLM(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        if self.config.model is None:
            self.config.model = "mosaicml/mpt-7b"

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        # Stream tokens to stdout when streaming is enabled, else print once at the end
        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]

        # Prepare the arguments for BaseVLLM
        llm_args = {
            "model": config.model,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "callback_manager": CallbackManager(callback_manager),
        }

        # Add model_kwargs if they are not None
        if config.model_kwargs is not None:
            llm_args.update(config.model_kwargs)

        llm = BaseVLLM(**llm_args)
        return llm(prompt)
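For reference, a minimal sketch of exercising the new class directly (assumes a machine where `vllm` and `langchain-community` are installed and the model fits in GPU memory; the model name and sampling values are illustrative):

```python
from embedchain.config import BaseLlmConfig
from embedchain.llm.vllm import VLLM

# Illustrative config; stream=False selects the plain stdout callback path
config = BaseLlmConfig(model="meta-llama/Llama-2-70b-hf", temperature=0.5, top_p=1, stream=False)
llm = VLLM(config=config)
print(llm.get_llm_model_answer("Summarize vLLM in one sentence."))
```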


@@ -6,7 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB

try:
-   from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusClient, connections, utility
+   from pymilvus import (
+       Collection,
+       CollectionSchema,
+       DataType,
+       FieldSchema,
+       MilvusClient,
+       connections,
+       utility,
+   )
except ImportError:
    raise ImportError(
        "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"


@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
-version = "0.1.57"
+version = "0.1.58"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [
    "Taranjeet Singh <taranjeet@embedchain.ai>",