[Feature] Add support for Mistral API (#1194)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-20 12:31:50 +05:30
committed by GitHub
parent 9afc6878c8
commit cb0499407e
9 changed files with 351 additions and 12 deletions

View File

@@ -20,6 +20,7 @@ Embedchain comes with built-in support for various popular large language models
<Card title="Hugging Face" href="#hugging-face"></Card>
<Card title="Llama2" href="#llama2"></Card>
<Card title="Vertex AI" href="#vertex-ai"></Card>
<Card title="Mistral AI" href="#mistral-ai"></Card>
</CardGroup>
## OpenAI
@@ -620,5 +621,47 @@ llm:
```
</CodeGroup>
## Mistral AI
Obtain the Mistral AI api key from their [console](https://console.mistral.ai/).
<CodeGroup>
```python main.py
import os
from embedchain import App
os.environ["MISTRAL_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
app.add("https://www.forbes.com/profile/elon-musk")
response = app.query("what is the net worth of Elon Musk?")
# As of January 16, 2024, Elon Musk's net worth is $225.4 billion.
response = app.chat("which companies does elon own?")
# Elon Musk owns Tesla, SpaceX, Boring Company, Twitter, and X.
response = app.chat("what question did I ask you already?")
# You have asked me several times already which companies Elon Musk owns, specifically Tesla, SpaceX, Boring Company, Twitter, and X.
```
```yaml config.yaml
llm:
provider: mistralai
config:
model: mistral-tiny
temperature: 0.5
max_tokens: 1000
top_p: 1
embedder:
provider: mistralai
config:
model: mistral-embed
```
</CodeGroup>
<br/ >
<Snippet file="missing-llm-tip.mdx" />

View File

@@ -0,0 +1,46 @@
import os
from typing import Optional, Union
from chromadb import EmbeddingFunction, Embeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class MistralAIEmbeddingFunction(EmbeddingFunction):
def __init__(self, config: BaseEmbedderConfig) -> None:
super().__init__()
try:
from langchain_mistralai import MistralAIEmbeddings
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for MistralAI are not installed."
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
) from None
self.config = config
api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
self.client = MistralAIEmbeddings(mistral_api_key=api_key)
self.client.model = self.config.model
def __call__(self, input: Union[list[str], str]) -> Embeddings:
if isinstance(input, str):
input_ = [input]
else:
input_ = input
response = self.client.embed_documents(input_)
return response
class MistralAIEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
if self.config.model is None:
self.config.model = "mistral-embed"
embedding_fn = MistralAIEmbeddingFunction(config=self.config)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -21,6 +21,7 @@ class LlmFactory:
"openai": "embedchain.llm.openai.OpenAILlm",
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
"google": "embedchain.llm.google.GoogleLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
}
provider_to_config_class = {
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
@@ -50,6 +51,7 @@ class EmbedderFactory:
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",

View File

@@ -0,0 +1,52 @@
import os
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@register_deserializable
class MistralAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
def get_llm_model_answer(self, prompt):
return MistralAILlm._get_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig):
try:
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mistralai.chat_models import ChatMistralAI
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for MistralAI are not installed."
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
) from None
api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
client = ChatMistralAI(mistral_api_key=api_key)
messages = []
if config.system_prompt:
messages.append(SystemMessage(content=config.system_prompt))
messages.append(HumanMessage(content=prompt))
kwargs = {
"model": config.model or "mistral-tiny",
"temperature": config.temperature,
"max_tokens": config.max_tokens,
"top_p": config.top_p,
}
# TODO: Add support for streaming
if config.stream:
answer = ""
for chunk in client.stream(**kwargs, input=messages):
answer += chunk.content
return answer
else:
response = client.invoke(**kwargs, input=messages)
answer = response.content
return answer

View File

@@ -8,3 +8,4 @@ class VectorDimensions(Enum):
VERTEX_AI = 768
HUGGING_FACE = 384
GOOGLE_AI = 768
MISTRAL_AI = 1024

View File

@@ -406,6 +406,7 @@ def validate_config(config_data):
"llama2",
"vertexai",
"google",
"mistralai",
),
Optional("config"): {
Optional("model"): str,
@@ -431,7 +432,15 @@ def validate_config(config_data):
Optional("config"): object, # TODO: add particular config schema for each provider
},
Optional("embedder"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("provider"): Or(
"openai",
"gpt4all",
"huggingface",
"vertexai",
"azure_openai",
"google",
"mistralai",
),
Optional("config"): {
Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str),
@@ -442,7 +451,15 @@ def validate_config(config_data):
},
},
Optional("embedding_model"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("provider"): Or(
"openai",
"gpt4all",
"huggingface",
"vertexai",
"azure_openai",
"google",
"mistralai",
),
Optional("config"): {
Optional("model"): str,
Optional("deployment_name"): str,

136
poetry.lock generated
View File

@@ -2487,24 +2487,24 @@ files = [
[[package]]
name = "httpcore"
version = "0.18.0"
version = "1.0.2"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpcore-0.18.0-py3-none-any.whl", hash = "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced"},
{file = "httpcore-0.18.0.tar.gz", hash = "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9"},
{file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
{file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
]
[package.dependencies]
anyio = ">=3.0,<5.0"
certifi = "*"
h11 = ">=0.13,<0.15"
sniffio = "==1.*"
[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.23.0)"]
[[package]]
name = "httplib2"
@@ -2569,21 +2569,22 @@ test = ["Cython (>=0.29.24,<0.30.0)"]
[[package]]
name = "httpx"
version = "0.25.0"
version = "0.25.2"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-0.25.0-py3-none-any.whl", hash = "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100"},
{file = "httpx-0.25.0.tar.gz", hash = "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875"},
{file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
{file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
]
[package.dependencies]
anyio = "*"
brotli = {version = "*", optional = true, markers = "platform_python_implementation == \"CPython\" and extra == \"brotli\""}
brotlicffi = {version = "*", optional = true, markers = "platform_python_implementation != \"CPython\" and extra == \"brotli\""}
certifi = "*"
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
httpcore = ">=0.18.0,<0.19.0"
httpcore = "==1.*"
idna = "*"
sniffio = "*"
socksio = {version = "==1.*", optional = true, markers = "extra == \"socks\""}
@@ -3024,6 +3025,45 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"]
qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"]
text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
[[package]]
name = "langchain-core"
version = "0.1.12"
description = "Building applications with LLMs through composability"
optional = true
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchain_core-0.1.12-py3-none-any.whl", hash = "sha256:d11c6262f7a9deff7de8fdf14498b8a951020dfed3a80f2358ab731ad04abef0"},
{file = "langchain_core-0.1.12.tar.gz", hash = "sha256:f18e9300e9a07589b3e280e51befbc5a4513f535949406e55eb7a2dc40c3ce66"},
]
[package.dependencies]
anyio = ">=3,<5"
jsonpatch = ">=1.33,<2.0"
langsmith = ">=0.0.63,<0.1.0"
packaging = ">=23.2,<24.0"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
requests = ">=2,<3"
tenacity = ">=8.1.0,<9.0.0"
[package.extras]
extended-testing = ["jinja2 (>=3,<4)"]
[[package]]
name = "langchain-mistralai"
version = "0.0.3"
description = "An integration package connecting Mistral and LangChain"
optional = true
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchain_mistralai-0.0.3-py3-none-any.whl", hash = "sha256:ebb8ba3d7978b5ee16f7e09512ffa434e00bc9863f1537f1a5f5203882d99619"},
{file = "langchain_mistralai-0.0.3.tar.gz", hash = "sha256:2e45ee0118df8e4b5577ce8c4f89743059801e473f40a8b7c89cb99dd715f423"},
]
[package.dependencies]
langchain-core = ">=0.1,<0.2"
mistralai = ">=0.0.11,<0.0.12"
[[package]]
name = "langdetect"
version = "1.0.9"
@@ -3458,6 +3498,22 @@ files = [
certifi = "*"
urllib3 = "*"
[[package]]
name = "mistralai"
version = "0.0.11"
description = ""
optional = true
python-versions = ">=3.8,<4.0"
files = [
{file = "mistralai-0.0.11-py3-none-any.whl", hash = "sha256:fb2a240a3985420c4e7db48eb5077d6d6dbc5e83cac0dd948c20342fb48087ee"},
{file = "mistralai-0.0.11.tar.gz", hash = "sha256:383072715531198305dab829ab3749b64933bbc2549354f3c9ebc43c17b912cf"},
]
[package.dependencies]
httpx = ">=0.25.2,<0.26.0"
orjson = ">=3.9.10,<4.0.0"
pydantic = ">=2.5.2,<3.0.0"
[[package]]
name = "mock"
version = "5.1.0"
@@ -4294,6 +4350,65 @@ files = [
{file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"},
]
[[package]]
name = "orjson"
version = "3.9.12"
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
optional = true
python-versions = ">=3.8"
files = [
{file = "orjson-3.9.12-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6b4e2bed7d00753c438e83b613923afdd067564ff7ed696bfe3a7b073a236e07"},
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd1b8ec63f0bf54a50b498eedeccdca23bd7b658f81c524d18e410c203189365"},
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab8add018a53665042a5ae68200f1ad14c7953fa12110d12d41166f111724656"},
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12756a108875526b76e505afe6d6ba34960ac6b8c5ec2f35faf73ef161e97e07"},
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:890e7519c0c70296253660455f77e3a194554a3c45e42aa193cdebc76a02d82b"},
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d664880d7f016efbae97c725b243b33c2cbb4851ddc77f683fd1eec4a7894146"},
{file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cfdaede0fa5b500314ec7b1249c7e30e871504a57004acd116be6acdda3b8ab3"},
{file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6492ff5953011e1ba9ed1bf086835fd574bd0a3cbe252db8e15ed72a30479081"},
{file = "orjson-3.9.12-cp310-none-win32.whl", hash = "sha256:29bf08e2eadb2c480fdc2e2daae58f2f013dff5d3b506edd1e02963b9ce9f8a9"},
{file = "orjson-3.9.12-cp310-none-win_amd64.whl", hash = "sha256:0fc156fba60d6b50743337ba09f052d8afc8b64595112996d22f5fce01ab57da"},
{file = "orjson-3.9.12-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2849f88a0a12b8d94579b67486cbd8f3a49e36a4cb3d3f0ab352c596078c730c"},
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3186b18754befa660b31c649a108a915493ea69b4fc33f624ed854ad3563ac65"},
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbbf313c9fb9d4f6cf9c22ced4b6682230457741daeb3d7060c5d06c2e73884a"},
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99e8cd005b3926c3db9b63d264bd05e1bf4451787cc79a048f27f5190a9a0311"},
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59feb148392d9155f3bfed0a2a3209268e000c2c3c834fb8fe1a6af9392efcbf"},
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4ae815a172a1f073b05b9e04273e3b23e608a0858c4e76f606d2d75fcabde0c"},
{file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed398f9a9d5a1bf55b6e362ffc80ac846af2122d14a8243a1e6510a4eabcb71e"},
{file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d3cfb76600c5a1e6be91326b8f3b83035a370e727854a96d801c1ea08b708073"},
{file = "orjson-3.9.12-cp311-none-win32.whl", hash = "sha256:a2b6f5252c92bcab3b742ddb3ac195c0fa74bed4319acd74f5d54d79ef4715dc"},
{file = "orjson-3.9.12-cp311-none-win_amd64.whl", hash = "sha256:c95488e4aa1d078ff5776b58f66bd29d628fa59adcb2047f4efd3ecb2bd41a71"},
{file = "orjson-3.9.12-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d6ce2062c4af43b92b0221ed4f445632c6bf4213f8a7da5396a122931377acd9"},
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:950951799967558c214cd6cceb7ceceed6f81d2c3c4135ee4a2c9c69f58aa225"},
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2dfaf71499d6fd4153f5c86eebb68e3ec1bf95851b030a4b55c7637a37bbdee4"},
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:659a8d7279e46c97661839035a1a218b61957316bf0202674e944ac5cfe7ed83"},
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af17fa87bccad0b7f6fd8ac8f9cbc9ee656b4552783b10b97a071337616db3e4"},
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd52dec9eddf4c8c74392f3fd52fa137b5f2e2bed1d9ae958d879de5f7d7cded"},
{file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:640e2b5d8e36b970202cfd0799d11a9a4ab46cf9212332cd642101ec952df7c8"},
{file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:daa438bd8024e03bcea2c5a92cd719a663a58e223fba967296b6ab9992259dbf"},
{file = "orjson-3.9.12-cp312-none-win_amd64.whl", hash = "sha256:1bb8f657c39ecdb924d02e809f992c9aafeb1ad70127d53fb573a6a6ab59d549"},
{file = "orjson-3.9.12-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f4098c7674901402c86ba6045a551a2ee345f9f7ed54eeffc7d86d155c8427e5"},
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5586a533998267458fad3a457d6f3cdbddbcce696c916599fa8e2a10a89b24d3"},
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:54071b7398cd3f90e4bb61df46705ee96cb5e33e53fc0b2f47dbd9b000e238e1"},
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67426651faa671b40443ea6f03065f9c8e22272b62fa23238b3efdacd301df31"},
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a0cd56e8ee56b203abae7d482ac0d233dbfb436bb2e2d5cbcb539fe1200a312"},
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84a0c3d4841a42e2571b1c1ead20a83e2792644c5827a606c50fc8af7ca4bee"},
{file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:09d60450cda3fa6c8ed17770c3a88473a16460cd0ff2ba74ef0df663b6fd3bb8"},
{file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bc82a4db9934a78ade211cf2e07161e4f068a461c1796465d10069cb50b32a80"},
{file = "orjson-3.9.12-cp38-none-win32.whl", hash = "sha256:61563d5d3b0019804d782137a4f32c72dc44c84e7d078b89d2d2a1adbaa47b52"},
{file = "orjson-3.9.12-cp38-none-win_amd64.whl", hash = "sha256:410f24309fbbaa2fab776e3212a81b96a1ec6037259359a32ea79fbccfcf76aa"},
{file = "orjson-3.9.12-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e773f251258dd82795fd5daeac081d00b97bacf1548e44e71245543374874bcf"},
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b159baecfda51c840a619948c25817d37733a4d9877fea96590ef8606468b362"},
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:975e72e81a249174840d5a8df977d067b0183ef1560a32998be340f7e195c730"},
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:06e42e899dde61eb1851a9fad7f1a21b8e4be063438399b63c07839b57668f6c"},
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c157e999e5694475a5515942aebeed6e43f7a1ed52267c1c93dcfde7d78d421"},
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dde1bc7c035f2d03aa49dc8642d9c6c9b1a81f2470e02055e76ed8853cfae0c3"},
{file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b0e9d73cdbdad76a53a48f563447e0e1ce34bcecef4614eb4b146383e6e7d8c9"},
{file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:96e44b21fe407b8ed48afbb3721f3c8c8ce17e345fbe232bd4651ace7317782d"},
{file = "orjson-3.9.12-cp39-none-win32.whl", hash = "sha256:cbd0f3555205bf2a60f8812133f2452d498dbefa14423ba90fe89f32276f7abf"},
{file = "orjson-3.9.12-cp39-none-win_amd64.whl", hash = "sha256:03ea7ee7e992532c2f4a06edd7ee1553f0644790553a118e003e3c405add41fa"},
{file = "orjson-3.9.12.tar.gz", hash = "sha256:da908d23a3b3243632b523344403b128722a5f45e278a8343c2bb67538dff0e4"},
]
[[package]]
name = "overrides"
version = "7.4.0"
@@ -8110,6 +8225,7 @@ googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-
huggingface-hub = ["huggingface_hub"]
llama2 = ["replicate"]
milvus = ["pymilvus"]
mistralai = ["langchain-mistralai"]
modal = ["modal"]
mysql = ["mysql-connector-python"]
opensearch = ["opensearch-py"]
@@ -8130,4 +8246,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
content-hash = "02bd85e14374a9dc9b59523b8fb4baea7068251976ba7f87722cac94a9974ccc"
content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"

View File

@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
google-auth = { version = "^2.25.2", optional = true }
google-auth-httplib2 = { version = "^0.2.0", optional = true }
google-api-core = { version = "^2.15.0", optional = true }
langchain-mistralai = { version = "^0.0.3", optional = true }
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
@@ -214,6 +215,7 @@ rss_feed = [
google = ["google-generativeai"]
modal = ["modal"]
dropbox = ["dropbox"]
mistralai = ["langchain-mistralai"]
[tool.poetry.group.docs.dependencies]

View File

@@ -0,0 +1,60 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.mistralai import MistralAILlm
@pytest.fixture
def mistralai_llm_config(monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
def test_mistralai_llm_init_missing_api_key(monkeypatch):
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
MistralAILlm()
def test_mistralai_llm_init(monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
llm = MistralAILlm()
assert llm is not None
def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
def mock_get_answer(prompt, config):
return "Generated Text"
monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
mistralai_llm_config.system_prompt = "Test system prompt"
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("")
assert result == "Generated Text"
def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
mistralai_llm_config.system_prompt = None
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"