diff --git a/configs/vllm.yaml b/configs/vllm.yaml
new file mode 100644
index 00000000..536a589a
--- /dev/null
+++ b/configs/vllm.yaml
@@ -0,0 +1,14 @@
+llm:
+  provider: vllm
+  config:
+    model: 'meta-llama/Llama-2-70b-hf'
+    temperature: 0.5
+    top_p: 1
+    top_k: 10
+    stream: true
+    trust_remote_code: true
+
+embedder:
+  provider: huggingface
+  config:
+    model: 'BAAI/bge-small-en-v1.5'
diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index faf64855..3825d017 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -14,6 +14,7 @@ Embedchain comes with built-in support for various popular large language models
+
@@ -393,6 +394,34 @@ llm:
+## vLLM
+
+Set up vLLM by following the instructions given in [their docs](https://docs.vllm.ai/en/latest/getting_started/installation.html).
+
+
+
+```python main.py
+import os
+from embedchain import App
+
+# load llm configuration from config.yaml file
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+llm:
+  provider: vllm
+  config:
+    model: 'meta-llama/Llama-2-70b-hf'
+    temperature: 0.5
+    top_p: 1
+    top_k: 10
+    stream: true
+    trust_remote_code: true
+```
+
+
+
 ## GPT4ALL
 
 Install related dependencies using the following command:
@@ -515,7 +544,7 @@ app = App.from_config(config_path="config.yaml")
 ```yaml config.yaml
 llm:
-  provider: huggingface
+  provider: huggingface
   config:
     endpoint: https://api-inference.huggingface.co/models/gpt2 # replace with your personal endpoint
 ```
@@ -525,7 +554,7 @@ If your endpoint requires additional parameters, you can pass them in the `model
 ```
 llm:
-  provider: huggingface
+  provider: huggingface
   config:
     endpoint:
     model_kwargs:
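For readers following along, the vLLM provider documented in the `llms.mdx` changes above maps onto the end-to-end flow below. This is a minimal sketch, assuming vLLM and its GPU dependencies are installed locally and that the `configs/vllm.yaml` file added in this diff is present on disk; the data source URL and the question are purely illustrative.

```python
from embedchain import App

# Load the vLLM-backed app from the config file added in this diff.
app = App.from_config(config_path="configs/vllm.yaml")

# Illustrative only: index a source, then answer a question over it with vLLM.
app.add("https://en.wikipedia.org/wiki/Large_language_model")
print(app.query("What is a large language model?"))
```
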
diff --git a/embedchain/app.py b/embedchain/app.py
index ab3643d3..45f422c9 100644
--- a/embedchain/app.py
+++ b/embedchain/app.py
@@ -9,14 +9,9 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
-from embedchain.cache import (
-    Config,
-    ExactMatchEvaluation,
-    SearchDistanceEvaluation,
-    cache,
-    gptcache_data_manager,
-    gptcache_pre_function,
-)
+from embedchain.cache import (Config, ExactMatchEvaluation,
+                              SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index 4748aa6c..7fa84c94 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -73,7 +73,7 @@ class BaseLlmConfig(BaseConfig):
         callbacks: Optional[List] = None,
         api_key: Optional[str] = None,
         endpoint: Optional[str] = None,
-        model_kwargs: Optional[Dict[str, Any]] = {},
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -115,6 +115,8 @@ class BaseLlmConfig(BaseConfig):
         :type model_kwargs: Optional[Dict[str, Any]], optional
         :param callbacks: Langchain callback functions to use, defaults to None
         :type callbacks: Optional[List], optional
+        :param query_type: The type of query to use, defaults to None
+        :type query_type: Optional[str], optional
         :raises ValueError: If the template is not valid as template should contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
@@ -142,6 +144,7 @@ class BaseLlmConfig(BaseConfig):
         self.api_key = api_key
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
+
         if type(prompt) is str:
             prompt = Template(prompt)
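One small but notable change above: `model_kwargs` now defaults to `None` instead of `{}`, presumably the standard fix for Python's shared mutable default pitfall, where a single dict instance is reused across every call. A generic illustration of that pitfall (not embedchain code):

```python
def configure(options={}):  # one dict object is shared across all calls
    options["touched"] = True
    return options

print(configure() is configure())  # True: both calls mutated the same dict

def configure_safe(options=None):  # the pattern the diff switches to
    options = {} if options is None else options
    options["touched"] = True
    return options

print(configure_safe() is configure_safe())  # False: a fresh dict per call
```
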
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 3e5babff..148d1f88 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -7,7 +7,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
-from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
+from embedchain.cache import (
+    adapt,
+    get_gptcache_session,
+    gptcache_data_convert,
+    gptcache_update_cache_callback,
+)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index 4cf13871..ba464da2 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -4,7 +4,9 @@ from typing import Any, Dict, Generator, List, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
+from embedchain.config.llm.base import (DEFAULT_PROMPT,
+                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
+                                        DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
diff --git a/embedchain/llm/vllm.py b/embedchain/llm/vllm.py
new file mode 100644
index 00000000..faac1f39
--- /dev/null
+++ b/embedchain/llm/vllm.py
@@ -0,0 +1,40 @@
+from typing import Iterable, Optional, Union
+
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks.stdout import StdOutCallbackHandler
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_community.llms import VLLM as BaseVLLM
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class VLLM(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config=config)
+        if self.config.model is None:
+            self.config.model = "mosaicml/mpt-7b"
+
+    def get_llm_model_answer(self, prompt):
+        return self._get_answer(prompt=prompt, config=self.config)
+
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
+        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
+
+        # Prepare the arguments for BaseVLLM
+        llm_args = {
+            "model": config.model,
+            "temperature": config.temperature,
+            "top_p": config.top_p,
+            "callback_manager": CallbackManager(callback_manager),
+        }
+
+        # Add model_kwargs if they are not None
+        if config.model_kwargs is not None:
+            llm_args.update(config.model_kwargs)
+
+        llm = BaseVLLM(**llm_args)
+        return llm(prompt)
diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py
index d49eabf8..a310a0ae 100644
--- a/embedchain/vectordb/zilliz.py
+++ b/embedchain/vectordb/zilliz.py
@@ -6,7 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 try:
-    from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusClient, connections, utility
+    from pymilvus import (
+        Collection,
+        CollectionSchema,
+        DataType,
+        FieldSchema,
+        MilvusClient,
+        connections,
+        utility,
+    )
 except ImportError:
     raise ImportError(
         "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
diff --git a/pyproject.toml b/pyproject.toml
index b2d0ab1a..bb958e14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.57"
+version = "0.1.58"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh ",
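To close, a sketch of driving the new `VLLM` wrapper directly rather than through `App.from_config`. The parameter values are illustrative; the point is that engine-specific options such as `top_k` or `trust_remote_code` can be supplied via `model_kwargs`, which `_get_answer` merges into the arguments forwarded to langchain's `VLLM` class.

```python
from embedchain.config import BaseLlmConfig
from embedchain.llm.vllm import VLLM

# Illustrative values; model_kwargs are forwarded verbatim to langchain's VLLM.
config = BaseLlmConfig(
    model="meta-llama/Llama-2-70b-hf",
    temperature=0.5,
    top_p=1,
    stream=True,
    model_kwargs={"top_k": 10, "trust_remote_code": True},
)

llm = VLLM(config=config)
print(llm.get_llm_model_answer("Explain what continuous batching does in vLLM."))
```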