diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx index fa872243..8cf90b74 100644 --- a/docs/components/llms.mdx +++ b/docs/components/llms.mdx @@ -22,6 +22,7 @@ Embedchain comes with built-in support for various popular large language models + ## OpenAI @@ -654,4 +655,60 @@ llm:
+ +## Groq + +[Groq](https://groq.com/) is the creator of the world's first Language Processing Unit (LPU), providing exceptional speed for AI workloads running on their LPU Inference Engine. + + +### Usage + +In order to use LLMs from Groq, go to their [platform](https://console.groq.com/keys) and get the API key. + +Set the API key as the `GROQ_API_KEY` environment variable or pass it in your app configuration to use the model as given below in the example. + + + +```python main.py +import os +from embedchain import App + +# Set your API key here or pass as the environment variable +groq_api_key = "gsk_xxxx" + +config = { + "llm": { + "provider": "groq", + "config": { + "model": "mixtral-8x7b-32768", + "api_key": groq_api_key, + "stream": True + } + } +} + +app = App.from_config(config=config) +# Add your data source here +app.add("https://docs.embedchain.ai/sitemap.xml", data_type="sitemap") +app.query("Write a poem about Embedchain") + +# In the realm of data, vast and wide, +# Embedchain stands with knowledge as its guide. +# A platform open, for all to try, +# Building bots that can truly fly. + +# With REST API, data in reach, +# Deployment a breeze, as easy as a speech. +# Updating data sources, anytime, anyday, +# Embedchain's power, never sway. + +# A knowledge base, an assistant so grand, +# Connecting to platforms, near and far. +# Discord, WhatsApp, Slack, and more, +# Embedchain's potential, never a bore. +``` + + +
+ diff --git a/embedchain/factory.py b/embedchain/factory.py index 9a772eaf..3b56b6ca 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -23,6 +23,7 @@ class LlmFactory: "google": "embedchain.llm.google.GoogleLlm", "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm", "mistralai": "embedchain.llm.mistralai.MistralAILlm", + "groq": "embedchain.llm.groq.GroqLlm", } provider_to_config_class = { "embedchain": "embedchain.config.llm.base.BaseLlmConfig", diff --git a/embedchain/llm/groq.py b/embedchain/llm/groq.py new file mode 100644 index 00000000..d658be15 --- /dev/null +++ b/embedchain/llm/groq.py @@ -0,0 +1,43 @@ +import os +from typing import Optional + +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.schema import HumanMessage, SystemMessage + +try: + from langchain_groq import ChatGroq +except ImportError: + raise ImportError("Groq requires extra dependencies. Install with `pip install langchain-groq`") from None + + +from embedchain.config import BaseLlmConfig +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.llm.base import BaseLlm + + +@register_deserializable +class GroqLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config=config) + + def get_llm_model_answer(self, prompt) -> str: + response = self._get_answer(prompt, self.config) + return response + + def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str: + messages = [] + if config.system_prompt: + messages.append(SystemMessage(content=config.system_prompt)) + messages.append(HumanMessage(content=prompt)) + api_key = config.api_key or os.environ["GROQ_API_KEY"] + kwargs = { + "model_name": config.model or "mixtral-8x7b-32768", + "temperature": config.temperature, + "groq_api_key": api_key, + } + if config.stream: + callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()] + chat = ChatGroq(**kwargs, 
streaming=config.stream, callbacks=callbacks, api_key=api_key) + else: + chat = ChatGroq(**kwargs) + return chat.invoke(messages).content diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py index af6e0a3f..709d2262 100644 --- a/embedchain/llm/openai.py +++ b/embedchain/llm/openai.py @@ -58,8 +58,7 @@ class OpenAILlm(BaseLlm): messages: list[BaseMessage], ) -> str: from langchain.output_parsers.openai_tools import JsonOutputToolsParser - from langchain_core.utils.function_calling import \ - convert_to_openai_tool + from langchain_core.utils.function_calling import convert_to_openai_tool openai_tools = [convert_to_openai_tool(tools)] chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser()) diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 55fe2cc8..ce7b5115 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -406,9 +406,11 @@ def validate_config(config_data): "aws_bedrock", "mistralai", "vllm", + "groq", ), Optional("config"): { Optional("model"): str, + Optional("model_name"): str, Optional("number_documents"): int, Optional("temperature"): float, Optional("max_tokens"): int, diff --git a/pyproject.toml b/pyproject.toml index ffbba155..9c9fb780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.85" +version = "0.1.86" description = "Simplest open source retrieval(RAG) framework" authors = [ "Taranjeet Singh ",