diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index bbf05a36..b9b4b138 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -20,6 +20,7 @@ Embedchain comes with built-in support for various popular large language models
+
## OpenAI
@@ -620,5 +621,47 @@ llm:
```
+
+## Mistral AI
+
+Obtain your Mistral AI API key from the [Mistral AI console](https://console.mistral.ai/).
+
+```python main.py
+import os
+from embedchain import App
+
+os.environ["MISTRAL_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+
+app.add("https://www.forbes.com/profile/elon-musk")
+
+response = app.query("what is the net worth of Elon Musk?")
+# As of January 16, 2024, Elon Musk's net worth is $225.4 billion.
+
+response = app.chat("which companies does elon own?")
+# Elon Musk owns Tesla, SpaceX, Boring Company, Twitter, and X.
+
+response = app.chat("what question did I ask you already?")
+# You have asked me several times already which companies Elon Musk owns, specifically Tesla, SpaceX, Boring Company, Twitter, and X.
+```
+
+```yaml config.yaml
+llm:
+ provider: mistralai
+ config:
+ model: mistral-tiny
+ temperature: 0.5
+ max_tokens: 1000
+ top_p: 1
+embedder:
+ provider: mistralai
+ config:
+ model: mistral-embed
+```
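+
+The `mistral-embed` model produces 1024-dimensional embeddings, and Embedchain sets this vector dimension for you, so no extra embedder settings are required. If you prefer not to rely on the `MISTRAL_API_KEY` environment variable, the key can also be supplied through the config. A minimal sketch, assuming the config schema accepts an `api_key` field:
+
+```yaml
+llm:
+  provider: mistralai
+  config:
+    model: mistral-tiny
+    api_key: xxx
+```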
+
diff --git a/embedchain/embedder/mistralai.py b/embedchain/embedder/mistralai.py
new file mode 100644
index 00000000..29db72ae
--- /dev/null
+++ b/embedchain/embedder/mistralai.py
@@ -0,0 +1,46 @@
+import os
+from typing import Optional, Union
+
+from chromadb import EmbeddingFunction, Embeddings
+
+from embedchain.config import BaseEmbedderConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.models import VectorDimensions
+
+
+class MistralAIEmbeddingFunction(EmbeddingFunction):
+ def __init__(self, config: BaseEmbedderConfig) -> None:
+ super().__init__()
+ try:
+ from langchain_mistralai import MistralAIEmbeddings
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "The required dependencies for MistralAI are not installed."
+ 'Please install with `pip install --upgrade "embedchain[mistralai]"`'
+ ) from None
+ self.config = config
+ api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
+ self.client = MistralAIEmbeddings(mistral_api_key=api_key)
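+        # Override the client's default model with the one from our config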
+ self.client.model = self.config.model
+
+ def __call__(self, input: Union[list[str], str]) -> Embeddings:
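+        # Accept either a single string or a list of strings; the embeddings client expects a list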
+ if isinstance(input, str):
+ input_ = [input]
+ else:
+ input_ = input
+ response = self.client.embed_documents(input_)
+ return response
+
+
+class MistralAIEmbedder(BaseEmbedder):
+ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+ super().__init__(config)
+
+ if self.config.model is None:
+ self.config.model = "mistral-embed"
+
+ embedding_fn = MistralAIEmbeddingFunction(config=self.config)
+ self.set_embedding_fn(embedding_fn=embedding_fn)
+
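+        # mistral-embed produces 1024-dimensional vectors (VectorDimensions.MISTRAL_AI)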
+ vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
+ self.set_vector_dimension(vector_dimension=vector_dimension)
diff --git a/embedchain/factory.py b/embedchain/factory.py
index 9fe9fb5b..04b20c56 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -21,6 +21,7 @@ class LlmFactory:
"openai": "embedchain.llm.openai.OpenAILlm",
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
"google": "embedchain.llm.google.GoogleLlm",
+ "mistralai": "embedchain.llm.mistralai.MistralAILlm",
}
provider_to_config_class = {
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
@@ -50,6 +51,7 @@ class EmbedderFactory:
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
+ "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
diff --git a/embedchain/llm/mistralai.py b/embedchain/llm/mistralai.py
new file mode 100644
index 00000000..6f2366bc
--- /dev/null
+++ b/embedchain/llm/mistralai.py
@@ -0,0 +1,52 @@
+import os
+from typing import Optional
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class MistralAILlm(BaseLlm):
+ def __init__(self, config: Optional[BaseLlmConfig] = None):
+ super().__init__(config)
+ if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
+ raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
+
+ def get_llm_model_answer(self, prompt):
+ return MistralAILlm._get_answer(prompt=prompt, config=self.config)
+
+ @staticmethod
+ def _get_answer(prompt: str, config: BaseLlmConfig):
+ try:
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from langchain_mistralai.chat_models import ChatMistralAI
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "The required dependencies for MistralAI are not installed."
+ 'Please install with `pip install --upgrade "embedchain[mistralai]"`'
+ ) from None
+
+ api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
+ client = ChatMistralAI(mistral_api_key=api_key)
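+        # Build the chat input: an optional system prompt followed by the user prompt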
+ messages = []
+ if config.system_prompt:
+ messages.append(SystemMessage(content=config.system_prompt))
+ messages.append(HumanMessage(content=prompt))
+ kwargs = {
+ "model": config.model or "mistral-tiny",
+ "temperature": config.temperature,
+ "max_tokens": config.max_tokens,
+ "top_p": config.top_p,
+ }
+
+        # When streaming is enabled, accumulate the streamed chunks into the final answer
+ if config.stream:
+ answer = ""
+ for chunk in client.stream(**kwargs, input=messages):
+ answer += chunk.content
+ return answer
+ else:
+ response = client.invoke(**kwargs, input=messages)
+ answer = response.content
+ return answer
diff --git a/embedchain/models/vector_dimensions.py b/embedchain/models/vector_dimensions.py
index 2bdaa0fa..e274d82f 100644
--- a/embedchain/models/vector_dimensions.py
+++ b/embedchain/models/vector_dimensions.py
@@ -8,3 +8,4 @@ class VectorDimensions(Enum):
VERTEX_AI = 768
HUGGING_FACE = 384
GOOGLE_AI = 768
+ MISTRAL_AI = 1024
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index c08665c2..6cc13ade 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -406,6 +406,7 @@ def validate_config(config_data):
"llama2",
"vertexai",
"google",
+ "mistralai",
),
Optional("config"): {
Optional("model"): str,
@@ -431,7 +432,15 @@ def validate_config(config_data):
Optional("config"): object, # TODO: add particular config schema for each provider
},
Optional("embedder"): {
- Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
+ Optional("provider"): Or(
+ "openai",
+ "gpt4all",
+ "huggingface",
+ "vertexai",
+ "azure_openai",
+ "google",
+ "mistralai",
+ ),
Optional("config"): {
Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str),
@@ -442,7 +451,15 @@ def validate_config(config_data):
},
},
Optional("embedding_model"): {
- Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
+ Optional("provider"): Or(
+ "openai",
+ "gpt4all",
+ "huggingface",
+ "vertexai",
+ "azure_openai",
+ "google",
+ "mistralai",
+ ),
Optional("config"): {
Optional("model"): str,
Optional("deployment_name"): str,
diff --git a/poetry.lock b/poetry.lock
index 0399cf16..b2599536 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2487,24 +2487,24 @@ files = [
[[package]]
name = "httpcore"
-version = "0.18.0"
+version = "1.0.2"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
- {file = "httpcore-0.18.0-py3-none-any.whl", hash = "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced"},
- {file = "httpcore-0.18.0.tar.gz", hash = "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9"},
+ {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
+ {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
]
[package.dependencies]
-anyio = ">=3.0,<5.0"
certifi = "*"
h11 = ">=0.13,<0.15"
-sniffio = "==1.*"
[package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<0.23.0)"]
[[package]]
name = "httplib2"
@@ -2569,21 +2569,22 @@ test = ["Cython (>=0.29.24,<0.30.0)"]
[[package]]
name = "httpx"
-version = "0.25.0"
+version = "0.25.2"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
files = [
- {file = "httpx-0.25.0-py3-none-any.whl", hash = "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100"},
- {file = "httpx-0.25.0.tar.gz", hash = "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875"},
+ {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
+ {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
]
[package.dependencies]
+anyio = "*"
brotli = {version = "*", optional = true, markers = "platform_python_implementation == \"CPython\" and extra == \"brotli\""}
brotlicffi = {version = "*", optional = true, markers = "platform_python_implementation != \"CPython\" and extra == \"brotli\""}
certifi = "*"
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
-httpcore = ">=0.18.0,<0.19.0"
+httpcore = "==1.*"
idna = "*"
sniffio = "*"
socksio = {version = "==1.*", optional = true, markers = "extra == \"socks\""}
@@ -3024,6 +3025,45 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"]
qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"]
text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
+[[package]]
+name = "langchain-core"
+version = "0.1.12"
+description = "Building applications with LLMs through composability"
+optional = true
+python-versions = ">=3.8.1,<4.0"
+files = [
+ {file = "langchain_core-0.1.12-py3-none-any.whl", hash = "sha256:d11c6262f7a9deff7de8fdf14498b8a951020dfed3a80f2358ab731ad04abef0"},
+ {file = "langchain_core-0.1.12.tar.gz", hash = "sha256:f18e9300e9a07589b3e280e51befbc5a4513f535949406e55eb7a2dc40c3ce66"},
+]
+
+[package.dependencies]
+anyio = ">=3,<5"
+jsonpatch = ">=1.33,<2.0"
+langsmith = ">=0.0.63,<0.1.0"
+packaging = ">=23.2,<24.0"
+pydantic = ">=1,<3"
+PyYAML = ">=5.3"
+requests = ">=2,<3"
+tenacity = ">=8.1.0,<9.0.0"
+
+[package.extras]
+extended-testing = ["jinja2 (>=3,<4)"]
+
+[[package]]
+name = "langchain-mistralai"
+version = "0.0.3"
+description = "An integration package connecting Mistral and LangChain"
+optional = true
+python-versions = ">=3.8.1,<4.0"
+files = [
+ {file = "langchain_mistralai-0.0.3-py3-none-any.whl", hash = "sha256:ebb8ba3d7978b5ee16f7e09512ffa434e00bc9863f1537f1a5f5203882d99619"},
+ {file = "langchain_mistralai-0.0.3.tar.gz", hash = "sha256:2e45ee0118df8e4b5577ce8c4f89743059801e473f40a8b7c89cb99dd715f423"},
+]
+
+[package.dependencies]
+langchain-core = ">=0.1,<0.2"
+mistralai = ">=0.0.11,<0.0.12"
+
[[package]]
name = "langdetect"
version = "1.0.9"
@@ -3458,6 +3498,22 @@ files = [
certifi = "*"
urllib3 = "*"
+[[package]]
+name = "mistralai"
+version = "0.0.11"
+description = ""
+optional = true
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "mistralai-0.0.11-py3-none-any.whl", hash = "sha256:fb2a240a3985420c4e7db48eb5077d6d6dbc5e83cac0dd948c20342fb48087ee"},
+ {file = "mistralai-0.0.11.tar.gz", hash = "sha256:383072715531198305dab829ab3749b64933bbc2549354f3c9ebc43c17b912cf"},
+]
+
+[package.dependencies]
+httpx = ">=0.25.2,<0.26.0"
+orjson = ">=3.9.10,<4.0.0"
+pydantic = ">=2.5.2,<3.0.0"
+
[[package]]
name = "mock"
version = "5.1.0"
@@ -4294,6 +4350,65 @@ files = [
{file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"},
]
+[[package]]
+name = "orjson"
+version = "3.9.12"
+description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "orjson-3.9.12-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6b4e2bed7d00753c438e83b613923afdd067564ff7ed696bfe3a7b073a236e07"},
+ {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd1b8ec63f0bf54a50b498eedeccdca23bd7b658f81c524d18e410c203189365"},
+ {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab8add018a53665042a5ae68200f1ad14c7953fa12110d12d41166f111724656"},
+ {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12756a108875526b76e505afe6d6ba34960ac6b8c5ec2f35faf73ef161e97e07"},
+ {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:890e7519c0c70296253660455f77e3a194554a3c45e42aa193cdebc76a02d82b"},
+ {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d664880d7f016efbae97c725b243b33c2cbb4851ddc77f683fd1eec4a7894146"},
+ {file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cfdaede0fa5b500314ec7b1249c7e30e871504a57004acd116be6acdda3b8ab3"},
+ {file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6492ff5953011e1ba9ed1bf086835fd574bd0a3cbe252db8e15ed72a30479081"},
+ {file = "orjson-3.9.12-cp310-none-win32.whl", hash = "sha256:29bf08e2eadb2c480fdc2e2daae58f2f013dff5d3b506edd1e02963b9ce9f8a9"},
+ {file = "orjson-3.9.12-cp310-none-win_amd64.whl", hash = "sha256:0fc156fba60d6b50743337ba09f052d8afc8b64595112996d22f5fce01ab57da"},
+ {file = "orjson-3.9.12-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2849f88a0a12b8d94579b67486cbd8f3a49e36a4cb3d3f0ab352c596078c730c"},
+ {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3186b18754befa660b31c649a108a915493ea69b4fc33f624ed854ad3563ac65"},
+ {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbbf313c9fb9d4f6cf9c22ced4b6682230457741daeb3d7060c5d06c2e73884a"},
+ {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99e8cd005b3926c3db9b63d264bd05e1bf4451787cc79a048f27f5190a9a0311"},
+ {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59feb148392d9155f3bfed0a2a3209268e000c2c3c834fb8fe1a6af9392efcbf"},
+ {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4ae815a172a1f073b05b9e04273e3b23e608a0858c4e76f606d2d75fcabde0c"},
+ {file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed398f9a9d5a1bf55b6e362ffc80ac846af2122d14a8243a1e6510a4eabcb71e"},
+ {file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d3cfb76600c5a1e6be91326b8f3b83035a370e727854a96d801c1ea08b708073"},
+ {file = "orjson-3.9.12-cp311-none-win32.whl", hash = "sha256:a2b6f5252c92bcab3b742ddb3ac195c0fa74bed4319acd74f5d54d79ef4715dc"},
+ {file = "orjson-3.9.12-cp311-none-win_amd64.whl", hash = "sha256:c95488e4aa1d078ff5776b58f66bd29d628fa59adcb2047f4efd3ecb2bd41a71"},
+ {file = "orjson-3.9.12-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d6ce2062c4af43b92b0221ed4f445632c6bf4213f8a7da5396a122931377acd9"},
+ {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:950951799967558c214cd6cceb7ceceed6f81d2c3c4135ee4a2c9c69f58aa225"},
+ {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2dfaf71499d6fd4153f5c86eebb68e3ec1bf95851b030a4b55c7637a37bbdee4"},
+ {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:659a8d7279e46c97661839035a1a218b61957316bf0202674e944ac5cfe7ed83"},
+ {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af17fa87bccad0b7f6fd8ac8f9cbc9ee656b4552783b10b97a071337616db3e4"},
+ {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd52dec9eddf4c8c74392f3fd52fa137b5f2e2bed1d9ae958d879de5f7d7cded"},
+ {file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:640e2b5d8e36b970202cfd0799d11a9a4ab46cf9212332cd642101ec952df7c8"},
+ {file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:daa438bd8024e03bcea2c5a92cd719a663a58e223fba967296b6ab9992259dbf"},
+ {file = "orjson-3.9.12-cp312-none-win_amd64.whl", hash = "sha256:1bb8f657c39ecdb924d02e809f992c9aafeb1ad70127d53fb573a6a6ab59d549"},
+ {file = "orjson-3.9.12-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f4098c7674901402c86ba6045a551a2ee345f9f7ed54eeffc7d86d155c8427e5"},
+ {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5586a533998267458fad3a457d6f3cdbddbcce696c916599fa8e2a10a89b24d3"},
+ {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:54071b7398cd3f90e4bb61df46705ee96cb5e33e53fc0b2f47dbd9b000e238e1"},
+ {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67426651faa671b40443ea6f03065f9c8e22272b62fa23238b3efdacd301df31"},
+ {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a0cd56e8ee56b203abae7d482ac0d233dbfb436bb2e2d5cbcb539fe1200a312"},
+ {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84a0c3d4841a42e2571b1c1ead20a83e2792644c5827a606c50fc8af7ca4bee"},
+ {file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:09d60450cda3fa6c8ed17770c3a88473a16460cd0ff2ba74ef0df663b6fd3bb8"},
+ {file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bc82a4db9934a78ade211cf2e07161e4f068a461c1796465d10069cb50b32a80"},
+ {file = "orjson-3.9.12-cp38-none-win32.whl", hash = "sha256:61563d5d3b0019804d782137a4f32c72dc44c84e7d078b89d2d2a1adbaa47b52"},
+ {file = "orjson-3.9.12-cp38-none-win_amd64.whl", hash = "sha256:410f24309fbbaa2fab776e3212a81b96a1ec6037259359a32ea79fbccfcf76aa"},
+ {file = "orjson-3.9.12-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e773f251258dd82795fd5daeac081d00b97bacf1548e44e71245543374874bcf"},
+ {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b159baecfda51c840a619948c25817d37733a4d9877fea96590ef8606468b362"},
+ {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:975e72e81a249174840d5a8df977d067b0183ef1560a32998be340f7e195c730"},
+ {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:06e42e899dde61eb1851a9fad7f1a21b8e4be063438399b63c07839b57668f6c"},
+ {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c157e999e5694475a5515942aebeed6e43f7a1ed52267c1c93dcfde7d78d421"},
+ {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dde1bc7c035f2d03aa49dc8642d9c6c9b1a81f2470e02055e76ed8853cfae0c3"},
+ {file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b0e9d73cdbdad76a53a48f563447e0e1ce34bcecef4614eb4b146383e6e7d8c9"},
+ {file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:96e44b21fe407b8ed48afbb3721f3c8c8ce17e345fbe232bd4651ace7317782d"},
+ {file = "orjson-3.9.12-cp39-none-win32.whl", hash = "sha256:cbd0f3555205bf2a60f8812133f2452d498dbefa14423ba90fe89f32276f7abf"},
+ {file = "orjson-3.9.12-cp39-none-win_amd64.whl", hash = "sha256:03ea7ee7e992532c2f4a06edd7ee1553f0644790553a118e003e3c405add41fa"},
+ {file = "orjson-3.9.12.tar.gz", hash = "sha256:da908d23a3b3243632b523344403b128722a5f45e278a8343c2bb67538dff0e4"},
+]
+
[[package]]
name = "overrides"
version = "7.4.0"
@@ -8110,6 +8225,7 @@ googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-
huggingface-hub = ["huggingface_hub"]
llama2 = ["replicate"]
milvus = ["pymilvus"]
+mistralai = ["langchain-mistralai"]
modal = ["modal"]
mysql = ["mysql-connector-python"]
opensearch = ["opensearch-py"]
@@ -8130,4 +8246,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
-content-hash = "02bd85e14374a9dc9b59523b8fb4baea7068251976ba7f87722cac94a9974ccc"
+content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"
diff --git a/pyproject.toml b/pyproject.toml
index 01e9bb5d..b7d9059b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
google-auth = { version = "^2.25.2", optional = true }
google-auth-httplib2 = { version = "^0.2.0", optional = true }
google-api-core = { version = "^2.15.0", optional = true }
+langchain-mistralai = { version = "^0.0.3", optional = true }
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
@@ -214,6 +215,7 @@ rss_feed = [
google = ["google-generativeai"]
modal = ["modal"]
dropbox = ["dropbox"]
+mistralai = ["langchain-mistralai"]
[tool.poetry.group.docs.dependencies]
diff --git a/tests/llm/test_mistralai.py b/tests/llm/test_mistralai.py
new file mode 100644
index 00000000..c99999c8
--- /dev/null
+++ b/tests/llm/test_mistralai.py
@@ -0,0 +1,60 @@
+import pytest
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.mistralai import MistralAILlm
+
+
+@pytest.fixture
+def mistralai_llm_config(monkeypatch):
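+    # A fake key is enough here; _get_answer is mocked in the tests below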
+ monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
+ yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
+ monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
+
+
+def test_mistralai_llm_init_missing_api_key(monkeypatch):
+ monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
+ with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
+ MistralAILlm()
+
+
+def test_mistralai_llm_init(monkeypatch):
+ monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
+ llm = MistralAILlm()
+ assert llm is not None
+
+
+def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
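+    # Stub the static _get_answer so no real Mistral API call is made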
+ def mock_get_answer(prompt, config):
+ return "Generated Text"
+
+ monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
+ llm = MistralAILlm(config=mistralai_llm_config)
+ result = llm.get_llm_model_answer("test prompt")
+
+ assert result == "Generated Text"
+
+
+def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
+ mistralai_llm_config.system_prompt = "Test system prompt"
+ monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+ llm = MistralAILlm(config=mistralai_llm_config)
+ result = llm.get_llm_model_answer("test prompt")
+
+ assert result == "Generated Text"
+
+
+def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
+ monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+ llm = MistralAILlm(config=mistralai_llm_config)
+ result = llm.get_llm_model_answer("")
+
+ assert result == "Generated Text"
+
+
+def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
+ mistralai_llm_config.system_prompt = None
+ monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+ llm = MistralAILlm(config=mistralai_llm_config)
+ result = llm.get_llm_model_answer("test prompt")
+
+ assert result == "Generated Text"