diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index fa872243..8cf90b74 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -22,6 +22,7 @@ Embedchain comes with built-in support for various popular large language models
+
## OpenAI
@@ -654,4 +655,60 @@ llm:
+
+## Groq
+
+[Groq](https://groq.com/) is the creator of the world's first Language Processing Unit (LPU), providing exceptional speed performance for AI workloads running on their LPU Inference Engine.
+
+
+### Usage
+
+In order to use LLMs from Groq, go to their [platform](https://console.groq.com/keys) and get the API key.
+
+Set the API key as `GROQ_API_KEY` environment variable or pass in your app configuration to use the model as given below in the example.
+
+
+
+```python main.py
+import os
+from embedchain import App
+
+# Set your API key here or pass as the environment variable
+groq_api_key = "gsk_xxxx"
+
+config = {
+ "llm": {
+ "provider": "groq",
+ "config": {
+ "model": "mixtral-8x7b-32768",
+ "api_key": groq_api_key,
+ "stream": True
+ }
+ }
+}
+
+app = App.from_config(config=config)
+# Add your data source here
+app.add("https://docs.embedchain.ai/sitemap.xml", data_type="sitemap")
+app.query("Write a poem about Embedchain")
+
+# In the realm of data, vast and wide,
+# Embedchain stands with knowledge as its guide.
+# A platform open, for all to try,
+# Building bots that can truly fly.
+
+# With REST API, data in reach,
+# Deployment a breeze, as easy as a speech.
+# Updating data sources, anytime, anyday,
+# Embedchain's power, never sway.
+
+# A knowledge base, an assistant so grand,
+# Connecting to platforms, near and far.
+# Discord, WhatsApp, Slack, and more,
+# Embedchain's potential, never a bore.
+```
+
+
+
+
diff --git a/embedchain/factory.py b/embedchain/factory.py
index 9a772eaf..3b56b6ca 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -23,6 +23,7 @@ class LlmFactory:
"google": "embedchain.llm.google.GoogleLlm",
"aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
+ "groq": "embedchain.llm.groq.GroqLlm",
}
provider_to_config_class = {
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
diff --git a/embedchain/llm/groq.py b/embedchain/llm/groq.py
new file mode 100644
index 00000000..d658be15
--- /dev/null
+++ b/embedchain/llm/groq.py
@@ -0,0 +1,43 @@
+import os
+from typing import Optional
+
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.schema import HumanMessage, SystemMessage
+
+try:
+ from langchain_groq import ChatGroq
+except ImportError:
+ raise ImportError("Groq requires extra dependencies. Install with `pip install langchain-groq`") from None
+
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class GroqLlm(BaseLlm):
+ def __init__(self, config: Optional[BaseLlmConfig] = None):
+ super().__init__(config=config)
+
+ def get_llm_model_answer(self, prompt) -> str:
+ response = self._get_answer(prompt, self.config)
+ return response
+
+ def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
+ messages = []
+ if config.system_prompt:
+ messages.append(SystemMessage(content=config.system_prompt))
+ messages.append(HumanMessage(content=prompt))
+ api_key = config.api_key or os.environ["GROQ_API_KEY"]
+ kwargs = {
+ "model_name": config.model or "mixtral-8x7b-32768",
+ "temperature": config.temperature,
+ "groq_api_key": api_key,
+ }
+ if config.stream:
+ callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
+ chat = ChatGroq(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+ else:
+ chat = ChatGroq(**kwargs)
+ return chat.invoke(messages).content
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index af6e0a3f..709d2262 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -58,8 +58,7 @@ class OpenAILlm(BaseLlm):
messages: list[BaseMessage],
) -> str:
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
- from langchain_core.utils.function_calling import \
- convert_to_openai_tool
+ from langchain_core.utils.function_calling import convert_to_openai_tool
openai_tools = [convert_to_openai_tool(tools)]
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 55fe2cc8..ce7b5115 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -406,9 +406,11 @@ def validate_config(config_data):
"aws_bedrock",
"mistralai",
"vllm",
+ "groq",
),
Optional("config"): {
Optional("model"): str,
+ Optional("model_name"): str,
Optional("number_documents"): int,
Optional("temperature"): float,
Optional("max_tokens"): int,
diff --git a/pyproject.toml b/pyproject.toml
index ffbba155..9c9fb780 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
-version = "0.1.85"
+version = "0.1.86"
description = "Simplest open source retrieval(RAG) framework"
authors = [
"Taranjeet Singh ",