[Feature] Add support for Mistral API (#1194)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -20,6 +20,7 @@ Embedchain comes with built-in support for various popular large language models
|
||||
<Card title="Hugging Face" href="#hugging-face"></Card>
|
||||
<Card title="Llama2" href="#llama2"></Card>
|
||||
<Card title="Vertex AI" href="#vertex-ai"></Card>
|
||||
<Card title="Mistral AI" href="#mistral-ai"></Card>
|
||||
</CardGroup>
|
||||
|
||||
## OpenAI
|
||||
@@ -620,5 +621,47 @@ llm:
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
|
||||
## Mistral AI
|
||||
|
||||
Obtain the Mistral AI api key from their [console](https://console.mistral.ai/).
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import App
|
||||
|
||||
os.environ["MISTRAL_API_KEY"] = "xxx"
|
||||
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
|
||||
app.add("https://www.forbes.com/profile/elon-musk")
|
||||
|
||||
response = app.query("what is the net worth of Elon Musk?")
|
||||
# As of January 16, 2024, Elon Musk's net worth is $225.4 billion.
|
||||
|
||||
response = app.chat("which companies does elon own?")
|
||||
# Elon Musk owns Tesla, SpaceX, Boring Company, Twitter, and X.
|
||||
|
||||
response = app.chat("what question did I ask you already?")
|
||||
# You have asked me several times already which companies Elon Musk owns, specifically Tesla, SpaceX, Boring Company, Twitter, and X.
|
||||
```
|
||||
|
||||
```yaml config.yaml
|
||||
llm:
|
||||
provider: mistralai
|
||||
config:
|
||||
model: mistral-tiny
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
embedder:
|
||||
provider: mistralai
|
||||
config:
|
||||
model: mistral-embed
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
<br/ >
|
||||
<Snippet file="missing-llm-tip.mdx" />
|
||||
|
||||
46
embedchain/embedder/mistralai.py
Normal file
46
embedchain/embedder/mistralai.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.models import VectorDimensions
|
||||
|
||||
|
||||
class MistralAIEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, config: BaseEmbedderConfig) -> None:
|
||||
super().__init__()
|
||||
try:
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The required dependencies for MistralAI are not installed."
|
||||
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
|
||||
) from None
|
||||
self.config = config
|
||||
api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
|
||||
self.client = MistralAIEmbeddings(mistral_api_key=api_key)
|
||||
self.client.model = self.config.model
|
||||
|
||||
def __call__(self, input: Union[list[str], str]) -> Embeddings:
|
||||
if isinstance(input, str):
|
||||
input_ = [input]
|
||||
else:
|
||||
input_ = input
|
||||
response = self.client.embed_documents(input_)
|
||||
return response
|
||||
|
||||
|
||||
class MistralAIEmbedder(BaseEmbedder):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config)
|
||||
|
||||
if self.config.model is None:
|
||||
self.config.model = "mistral-embed"
|
||||
|
||||
embedding_fn = MistralAIEmbeddingFunction(config=self.config)
|
||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||
|
||||
vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
|
||||
self.set_vector_dimension(vector_dimension=vector_dimension)
|
||||
@@ -21,6 +21,7 @@ class LlmFactory:
|
||||
"openai": "embedchain.llm.openai.OpenAILlm",
|
||||
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
|
||||
"google": "embedchain.llm.google.GoogleLlm",
|
||||
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
||||
@@ -50,6 +51,7 @@ class EmbedderFactory:
|
||||
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
||||
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
|
||||
"google": "embedchain.embedder.google.GoogleAIEmbedder",
|
||||
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
|
||||
52
embedchain/llm/mistralai.py
Normal file
52
embedchain/llm/mistralai.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class MistralAILlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config)
|
||||
if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
return MistralAILlm._get_answer(prompt=prompt, config=self.config)
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig):
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The required dependencies for MistralAI are not installed."
|
||||
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
|
||||
) from None
|
||||
|
||||
api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
|
||||
client = ChatMistralAI(mistral_api_key=api_key)
|
||||
messages = []
|
||||
if config.system_prompt:
|
||||
messages.append(SystemMessage(content=config.system_prompt))
|
||||
messages.append(HumanMessage(content=prompt))
|
||||
kwargs = {
|
||||
"model": config.model or "mistral-tiny",
|
||||
"temperature": config.temperature,
|
||||
"max_tokens": config.max_tokens,
|
||||
"top_p": config.top_p,
|
||||
}
|
||||
|
||||
# TODO: Add support for streaming
|
||||
if config.stream:
|
||||
answer = ""
|
||||
for chunk in client.stream(**kwargs, input=messages):
|
||||
answer += chunk.content
|
||||
return answer
|
||||
else:
|
||||
response = client.invoke(**kwargs, input=messages)
|
||||
answer = response.content
|
||||
return answer
|
||||
@@ -8,3 +8,4 @@ class VectorDimensions(Enum):
|
||||
VERTEX_AI = 768
|
||||
HUGGING_FACE = 384
|
||||
GOOGLE_AI = 768
|
||||
MISTRAL_AI = 1024
|
||||
|
||||
@@ -406,6 +406,7 @@ def validate_config(config_data):
|
||||
"llama2",
|
||||
"vertexai",
|
||||
"google",
|
||||
"mistralai",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
@@ -431,7 +432,15 @@ def validate_config(config_data):
|
||||
Optional("config"): object, # TODO: add particular config schema for each provider
|
||||
},
|
||||
Optional("embedder"): {
|
||||
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
|
||||
Optional("provider"): Or(
|
||||
"openai",
|
||||
"gpt4all",
|
||||
"huggingface",
|
||||
"vertexai",
|
||||
"azure_openai",
|
||||
"google",
|
||||
"mistralai",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): Optional(str),
|
||||
Optional("deployment_name"): Optional(str),
|
||||
@@ -442,7 +451,15 @@ def validate_config(config_data):
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
|
||||
Optional("provider"): Or(
|
||||
"openai",
|
||||
"gpt4all",
|
||||
"huggingface",
|
||||
"vertexai",
|
||||
"azure_openai",
|
||||
"google",
|
||||
"mistralai",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
Optional("deployment_name"): str,
|
||||
|
||||
136
poetry.lock
generated
136
poetry.lock
generated
@@ -2487,24 +2487,24 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "0.18.0"
|
||||
version = "1.0.2"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpcore-0.18.0-py3-none-any.whl", hash = "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced"},
|
||||
{file = "httpcore-0.18.0.tar.gz", hash = "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9"},
|
||||
{file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
|
||||
{file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3.0,<5.0"
|
||||
certifi = "*"
|
||||
h11 = ">=0.13,<0.15"
|
||||
sniffio = "==1.*"
|
||||
|
||||
[package.extras]
|
||||
asyncio = ["anyio (>=4.0,<5.0)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
trio = ["trio (>=0.22.0,<0.23.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httplib2"
|
||||
@@ -2569,21 +2569,22 @@ test = ["Cython (>=0.29.24,<0.30.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.25.0"
|
||||
version = "0.25.2"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-0.25.0-py3-none-any.whl", hash = "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100"},
|
||||
{file = "httpx-0.25.0.tar.gz", hash = "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875"},
|
||||
{file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
|
||||
{file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = "*"
|
||||
brotli = {version = "*", optional = true, markers = "platform_python_implementation == \"CPython\" and extra == \"brotli\""}
|
||||
brotlicffi = {version = "*", optional = true, markers = "platform_python_implementation != \"CPython\" and extra == \"brotli\""}
|
||||
certifi = "*"
|
||||
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
|
||||
httpcore = ">=0.18.0,<0.19.0"
|
||||
httpcore = "==1.*"
|
||||
idna = "*"
|
||||
sniffio = "*"
|
||||
socksio = {version = "==1.*", optional = true, markers = "extra == \"socks\""}
|
||||
@@ -3024,6 +3025,45 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"]
|
||||
qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"]
|
||||
text-helpers = ["chardet (>=5.1.0,<6.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.12"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = true
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langchain_core-0.1.12-py3-none-any.whl", hash = "sha256:d11c6262f7a9deff7de8fdf14498b8a951020dfed3a80f2358ab731ad04abef0"},
|
||||
{file = "langchain_core-0.1.12.tar.gz", hash = "sha256:f18e9300e9a07589b3e280e51befbc5a4513f535949406e55eb7a2dc40c3ce66"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3,<5"
|
||||
jsonpatch = ">=1.33,<2.0"
|
||||
langsmith = ">=0.0.63,<0.1.0"
|
||||
packaging = ">=23.2,<24.0"
|
||||
pydantic = ">=1,<3"
|
||||
PyYAML = ">=5.3"
|
||||
requests = ">=2,<3"
|
||||
tenacity = ">=8.1.0,<9.0.0"
|
||||
|
||||
[package.extras]
|
||||
extended-testing = ["jinja2 (>=3,<4)"]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-mistralai"
|
||||
version = "0.0.3"
|
||||
description = "An integration package connecting Mistral and LangChain"
|
||||
optional = true
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langchain_mistralai-0.0.3-py3-none-any.whl", hash = "sha256:ebb8ba3d7978b5ee16f7e09512ffa434e00bc9863f1537f1a5f5203882d99619"},
|
||||
{file = "langchain_mistralai-0.0.3.tar.gz", hash = "sha256:2e45ee0118df8e4b5577ce8c4f89743059801e473f40a8b7c89cb99dd715f423"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
langchain-core = ">=0.1,<0.2"
|
||||
mistralai = ">=0.0.11,<0.0.12"
|
||||
|
||||
[[package]]
|
||||
name = "langdetect"
|
||||
version = "1.0.9"
|
||||
@@ -3458,6 +3498,22 @@ files = [
|
||||
certifi = "*"
|
||||
urllib3 = "*"
|
||||
|
||||
[[package]]
|
||||
name = "mistralai"
|
||||
version = "0.0.11"
|
||||
description = ""
|
||||
optional = true
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "mistralai-0.0.11-py3-none-any.whl", hash = "sha256:fb2a240a3985420c4e7db48eb5077d6d6dbc5e83cac0dd948c20342fb48087ee"},
|
||||
{file = "mistralai-0.0.11.tar.gz", hash = "sha256:383072715531198305dab829ab3749b64933bbc2549354f3c9ebc43c17b912cf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.25.2,<0.26.0"
|
||||
orjson = ">=3.9.10,<4.0.0"
|
||||
pydantic = ">=2.5.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "mock"
|
||||
version = "5.1.0"
|
||||
@@ -4294,6 +4350,65 @@ files = [
|
||||
{file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.9.12"
|
||||
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "orjson-3.9.12-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6b4e2bed7d00753c438e83b613923afdd067564ff7ed696bfe3a7b073a236e07"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd1b8ec63f0bf54a50b498eedeccdca23bd7b658f81c524d18e410c203189365"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab8add018a53665042a5ae68200f1ad14c7953fa12110d12d41166f111724656"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12756a108875526b76e505afe6d6ba34960ac6b8c5ec2f35faf73ef161e97e07"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:890e7519c0c70296253660455f77e3a194554a3c45e42aa193cdebc76a02d82b"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d664880d7f016efbae97c725b243b33c2cbb4851ddc77f683fd1eec4a7894146"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cfdaede0fa5b500314ec7b1249c7e30e871504a57004acd116be6acdda3b8ab3"},
|
||||
{file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6492ff5953011e1ba9ed1bf086835fd574bd0a3cbe252db8e15ed72a30479081"},
|
||||
{file = "orjson-3.9.12-cp310-none-win32.whl", hash = "sha256:29bf08e2eadb2c480fdc2e2daae58f2f013dff5d3b506edd1e02963b9ce9f8a9"},
|
||||
{file = "orjson-3.9.12-cp310-none-win_amd64.whl", hash = "sha256:0fc156fba60d6b50743337ba09f052d8afc8b64595112996d22f5fce01ab57da"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2849f88a0a12b8d94579b67486cbd8f3a49e36a4cb3d3f0ab352c596078c730c"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3186b18754befa660b31c649a108a915493ea69b4fc33f624ed854ad3563ac65"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbbf313c9fb9d4f6cf9c22ced4b6682230457741daeb3d7060c5d06c2e73884a"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99e8cd005b3926c3db9b63d264bd05e1bf4451787cc79a048f27f5190a9a0311"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59feb148392d9155f3bfed0a2a3209268e000c2c3c834fb8fe1a6af9392efcbf"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4ae815a172a1f073b05b9e04273e3b23e608a0858c4e76f606d2d75fcabde0c"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed398f9a9d5a1bf55b6e362ffc80ac846af2122d14a8243a1e6510a4eabcb71e"},
|
||||
{file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d3cfb76600c5a1e6be91326b8f3b83035a370e727854a96d801c1ea08b708073"},
|
||||
{file = "orjson-3.9.12-cp311-none-win32.whl", hash = "sha256:a2b6f5252c92bcab3b742ddb3ac195c0fa74bed4319acd74f5d54d79ef4715dc"},
|
||||
{file = "orjson-3.9.12-cp311-none-win_amd64.whl", hash = "sha256:c95488e4aa1d078ff5776b58f66bd29d628fa59adcb2047f4efd3ecb2bd41a71"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d6ce2062c4af43b92b0221ed4f445632c6bf4213f8a7da5396a122931377acd9"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:950951799967558c214cd6cceb7ceceed6f81d2c3c4135ee4a2c9c69f58aa225"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2dfaf71499d6fd4153f5c86eebb68e3ec1bf95851b030a4b55c7637a37bbdee4"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:659a8d7279e46c97661839035a1a218b61957316bf0202674e944ac5cfe7ed83"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af17fa87bccad0b7f6fd8ac8f9cbc9ee656b4552783b10b97a071337616db3e4"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd52dec9eddf4c8c74392f3fd52fa137b5f2e2bed1d9ae958d879de5f7d7cded"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:640e2b5d8e36b970202cfd0799d11a9a4ab46cf9212332cd642101ec952df7c8"},
|
||||
{file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:daa438bd8024e03bcea2c5a92cd719a663a58e223fba967296b6ab9992259dbf"},
|
||||
{file = "orjson-3.9.12-cp312-none-win_amd64.whl", hash = "sha256:1bb8f657c39ecdb924d02e809f992c9aafeb1ad70127d53fb573a6a6ab59d549"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f4098c7674901402c86ba6045a551a2ee345f9f7ed54eeffc7d86d155c8427e5"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5586a533998267458fad3a457d6f3cdbddbcce696c916599fa8e2a10a89b24d3"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:54071b7398cd3f90e4bb61df46705ee96cb5e33e53fc0b2f47dbd9b000e238e1"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67426651faa671b40443ea6f03065f9c8e22272b62fa23238b3efdacd301df31"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a0cd56e8ee56b203abae7d482ac0d233dbfb436bb2e2d5cbcb539fe1200a312"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84a0c3d4841a42e2571b1c1ead20a83e2792644c5827a606c50fc8af7ca4bee"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:09d60450cda3fa6c8ed17770c3a88473a16460cd0ff2ba74ef0df663b6fd3bb8"},
|
||||
{file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bc82a4db9934a78ade211cf2e07161e4f068a461c1796465d10069cb50b32a80"},
|
||||
{file = "orjson-3.9.12-cp38-none-win32.whl", hash = "sha256:61563d5d3b0019804d782137a4f32c72dc44c84e7d078b89d2d2a1adbaa47b52"},
|
||||
{file = "orjson-3.9.12-cp38-none-win_amd64.whl", hash = "sha256:410f24309fbbaa2fab776e3212a81b96a1ec6037259359a32ea79fbccfcf76aa"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e773f251258dd82795fd5daeac081d00b97bacf1548e44e71245543374874bcf"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b159baecfda51c840a619948c25817d37733a4d9877fea96590ef8606468b362"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:975e72e81a249174840d5a8df977d067b0183ef1560a32998be340f7e195c730"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:06e42e899dde61eb1851a9fad7f1a21b8e4be063438399b63c07839b57668f6c"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c157e999e5694475a5515942aebeed6e43f7a1ed52267c1c93dcfde7d78d421"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dde1bc7c035f2d03aa49dc8642d9c6c9b1a81f2470e02055e76ed8853cfae0c3"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b0e9d73cdbdad76a53a48f563447e0e1ce34bcecef4614eb4b146383e6e7d8c9"},
|
||||
{file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:96e44b21fe407b8ed48afbb3721f3c8c8ce17e345fbe232bd4651ace7317782d"},
|
||||
{file = "orjson-3.9.12-cp39-none-win32.whl", hash = "sha256:cbd0f3555205bf2a60f8812133f2452d498dbefa14423ba90fe89f32276f7abf"},
|
||||
{file = "orjson-3.9.12-cp39-none-win_amd64.whl", hash = "sha256:03ea7ee7e992532c2f4a06edd7ee1553f0644790553a118e003e3c405add41fa"},
|
||||
{file = "orjson-3.9.12.tar.gz", hash = "sha256:da908d23a3b3243632b523344403b128722a5f45e278a8343c2bb67538dff0e4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "overrides"
|
||||
version = "7.4.0"
|
||||
@@ -8110,6 +8225,7 @@ googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-
|
||||
huggingface-hub = ["huggingface_hub"]
|
||||
llama2 = ["replicate"]
|
||||
milvus = ["pymilvus"]
|
||||
mistralai = ["langchain-mistralai"]
|
||||
modal = ["modal"]
|
||||
mysql = ["mysql-connector-python"]
|
||||
opensearch = ["opensearch-py"]
|
||||
@@ -8130,4 +8246,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.12"
|
||||
content-hash = "02bd85e14374a9dc9b59523b8fb4baea7068251976ba7f87722cac94a9974ccc"
|
||||
content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"
|
||||
|
||||
@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
|
||||
google-auth = { version = "^2.25.2", optional = true }
|
||||
google-auth-httplib2 = { version = "^0.2.0", optional = true }
|
||||
google-api-core = { version = "^2.15.0", optional = true }
|
||||
langchain-mistralai = { version = "^0.0.3", optional = true }
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.3.0"
|
||||
@@ -214,6 +215,7 @@ rss_feed = [
|
||||
google = ["google-generativeai"]
|
||||
modal = ["modal"]
|
||||
dropbox = ["dropbox"]
|
||||
mistralai = ["langchain-mistralai"]
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
|
||||
|
||||
60
tests/llm/test_mistralai.py
Normal file
60
tests/llm/test_mistralai.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.mistralai import MistralAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mistralai_llm_config(monkeypatch):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
|
||||
yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
|
||||
|
||||
def test_mistralai_llm_init_missing_api_key(monkeypatch):
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
|
||||
MistralAILlm()
|
||||
|
||||
|
||||
def test_mistralai_llm_init(monkeypatch):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
|
||||
llm = MistralAILlm()
|
||||
assert llm is not None
|
||||
|
||||
|
||||
def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
|
||||
def mock_get_answer(prompt, config):
|
||||
return "Generated Text"
|
||||
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
|
||||
mistralai_llm_config.system_prompt = "Test system prompt"
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
|
||||
mistralai_llm_config.system_prompt = None
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
Reference in New Issue
Block a user