From cb0499407e9c48dbad521f2c2289fd2f40f72a8b Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Sat, 20 Jan 2024 12:31:50 +0530 Subject: [PATCH] [Feature] Add support for Mistral API (#1194) Co-authored-by: Deven Patel --- docs/components/llms.mdx | 43 ++++++++ embedchain/embedder/mistralai.py | 46 +++++++++ embedchain/factory.py | 2 + embedchain/llm/mistralai.py | 52 ++++++++++ embedchain/models/vector_dimensions.py | 1 + embedchain/utils/misc.py | 21 +++- poetry.lock | 136 +++++++++++++++++++++++-- pyproject.toml | 2 + tests/llm/test_mistralai.py | 60 +++++++++++ 9 files changed, 351 insertions(+), 12 deletions(-) create mode 100644 embedchain/embedder/mistralai.py create mode 100644 embedchain/llm/mistralai.py create mode 100644 tests/llm/test_mistralai.py diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx index bbf05a36..b9b4b138 100644 --- a/docs/components/llms.mdx +++ b/docs/components/llms.mdx @@ -20,6 +20,7 @@ Embedchain comes with built-in support for various popular large language models + ## OpenAI @@ -620,5 +621,47 @@ llm: ``` + +## Mistral AI + +Obtain the Mistral AI api key from their [console](https://console.mistral.ai/). + + + +```python main.py +import os +from embedchain import App + +os.environ["MISTRAL_API_KEY"] = "xxx" + +app = App.from_config(config_path="config.yaml") + +app.add("https://www.forbes.com/profile/elon-musk") + +response = app.query("what is the net worth of Elon Musk?") +# As of January 16, 2024, Elon Musk's net worth is $225.4 billion. + +response = app.chat("which companies does elon own?") +# Elon Musk owns Tesla, SpaceX, Boring Company, Twitter, and X. + +response = app.chat("what question did I ask you already?") +# You have asked me several times already which companies Elon Musk owns, specifically Tesla, SpaceX, Boring Company, Twitter, and X. +``` + +```yaml config.yaml +llm: + provider: mistralai + config: + model: mistral-tiny + temperature: 0.5 + max_tokens: 1000 + top_p: 1 +embedder: + provider: mistralai + config: + model: mistral-embed +``` + +
diff --git a/embedchain/embedder/mistralai.py b/embedchain/embedder/mistralai.py new file mode 100644 index 00000000..29db72ae --- /dev/null +++ b/embedchain/embedder/mistralai.py @@ -0,0 +1,46 @@ +import os +from typing import Optional, Union + +from chromadb import EmbeddingFunction, Embeddings + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base import BaseEmbedder +from embedchain.models import VectorDimensions + + +class MistralAIEmbeddingFunction(EmbeddingFunction): + def __init__(self, config: BaseEmbedderConfig) -> None: + super().__init__() + try: + from langchain_mistralai import MistralAIEmbeddings + except ModuleNotFoundError: + raise ModuleNotFoundError( + "The required dependencies for MistralAI are not installed." + 'Please install with `pip install --upgrade "embedchain[mistralai]"`' + ) from None + self.config = config + api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY") + self.client = MistralAIEmbeddings(mistral_api_key=api_key) + self.client.model = self.config.model + + def __call__(self, input: Union[list[str], str]) -> Embeddings: + if isinstance(input, str): + input_ = [input] + else: + input_ = input + response = self.client.embed_documents(input_) + return response + + +class MistralAIEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config) + + if self.config.model is None: + self.config.model = "mistral-embed" + + embedding_fn = MistralAIEmbeddingFunction(config=self.config) + self.set_embedding_fn(embedding_fn=embedding_fn) + + vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/factory.py b/embedchain/factory.py index 9fe9fb5b..04b20c56 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -21,6 +21,7 @@ class LlmFactory: "openai": "embedchain.llm.openai.OpenAILlm", "vertexai": "embedchain.llm.vertex_ai.VertexAILlm", "google": "embedchain.llm.google.GoogleLlm", + "mistralai": "embedchain.llm.mistralai.MistralAILlm", } provider_to_config_class = { "embedchain": "embedchain.config.llm.base.BaseLlmConfig", @@ -50,6 +51,7 @@ class EmbedderFactory: "openai": "embedchain.embedder.openai.OpenAIEmbedder", "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder", "google": "embedchain.embedder.google.GoogleAIEmbedder", + "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder", } provider_to_config_class = { "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", diff --git a/embedchain/llm/mistralai.py b/embedchain/llm/mistralai.py new file mode 100644 index 00000000..6f2366bc --- /dev/null +++ b/embedchain/llm/mistralai.py @@ -0,0 +1,52 @@ +import os +from typing import Optional + +from embedchain.config import BaseLlmConfig +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.llm.base import BaseLlm + + +@register_deserializable +class MistralAILlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ: + raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.") + + def get_llm_model_answer(self, prompt): + return MistralAILlm._get_answer(prompt=prompt, config=self.config) + + @staticmethod + def _get_answer(prompt: str, config: BaseLlmConfig): + try: + from langchain_core.messages import HumanMessage, SystemMessage + from langchain_mistralai.chat_models import ChatMistralAI + except ModuleNotFoundError: + raise ModuleNotFoundError( + "The required dependencies for MistralAI are not installed." + 'Please install with `pip install --upgrade "embedchain[mistralai]"`' + ) from None + + api_key = config.api_key or os.getenv("MISTRAL_API_KEY") + client = ChatMistralAI(mistral_api_key=api_key) + messages = [] + if config.system_prompt: + messages.append(SystemMessage(content=config.system_prompt)) + messages.append(HumanMessage(content=prompt)) + kwargs = { + "model": config.model or "mistral-tiny", + "temperature": config.temperature, + "max_tokens": config.max_tokens, + "top_p": config.top_p, + } + + # TODO: Add support for streaming + if config.stream: + answer = "" + for chunk in client.stream(**kwargs, input=messages): + answer += chunk.content + return answer + else: + response = client.invoke(**kwargs, input=messages) + answer = response.content + return answer diff --git a/embedchain/models/vector_dimensions.py b/embedchain/models/vector_dimensions.py index 2bdaa0fa..e274d82f 100644 --- a/embedchain/models/vector_dimensions.py +++ b/embedchain/models/vector_dimensions.py @@ -8,3 +8,4 @@ class VectorDimensions(Enum): VERTEX_AI = 768 HUGGING_FACE = 384 GOOGLE_AI = 768 + MISTRAL_AI = 1024 diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index c08665c2..6cc13ade 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -406,6 +406,7 @@ def validate_config(config_data): "llama2", "vertexai", "google", + "mistralai", ), Optional("config"): { Optional("model"): str, @@ -431,7 +432,15 @@ def validate_config(config_data): Optional("config"): object, # TODO: add particular config schema for each provider }, Optional("embedder"): { - Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"), + Optional("provider"): Or( + "openai", + "gpt4all", + "huggingface", + "vertexai", + "azure_openai", + "google", + "mistralai", + ), Optional("config"): { Optional("model"): Optional(str), Optional("deployment_name"): Optional(str), @@ -442,7 +451,15 @@ def validate_config(config_data): }, }, Optional("embedding_model"): { - Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"), + Optional("provider"): Or( + "openai", + "gpt4all", + "huggingface", + "vertexai", + "azure_openai", + "google", + "mistralai", + ), Optional("config"): { Optional("model"): str, Optional("deployment_name"): str, diff --git a/poetry.lock b/poetry.lock index 0399cf16..b2599536 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2487,24 +2487,24 @@ files = [ [[package]] name = "httpcore" -version = "0.18.0" +version = "1.0.2" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-0.18.0-py3-none-any.whl", hash = "sha256:adc5398ee0a476567bf87467063ee63584a8bce86078bf748e48754f60202ced"}, - {file = "httpcore-0.18.0.tar.gz", hash = "sha256:13b5e5cd1dca1a6636a6aaea212b19f4f85cd88c366a2b82304181b769aab3c9"}, + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, ] [package.dependencies] -anyio = ">=3.0,<5.0" certifi = "*" h11 = ">=0.13,<0.15" -sniffio = "==1.*" [package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] [[package]] name = "httplib2" @@ -2569,21 +2569,22 @@ test = ["Cython (>=0.29.24,<0.30.0)"] [[package]] name = "httpx" -version = "0.25.0" +version = "0.25.2" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.25.0-py3-none-any.whl", hash = "sha256:181ea7f8ba3a82578be86ef4171554dd45fec26a02556a744db029a0a27b7100"}, - {file = "httpx-0.25.0.tar.gz", hash = "sha256:47ecda285389cb32bb2691cc6e069e3ab0205956f681c5b2ad2325719751d875"}, + {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, + {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, ] [package.dependencies] +anyio = "*" brotli = {version = "*", optional = true, markers = "platform_python_implementation == \"CPython\" and extra == \"brotli\""} brotlicffi = {version = "*", optional = true, markers = "platform_python_implementation != \"CPython\" and extra == \"brotli\""} certifi = "*" h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} -httpcore = ">=0.18.0,<0.19.0" +httpcore = "==1.*" idna = "*" sniffio = "*" socksio = {version = "==1.*", optional = true, markers = "extra == \"socks\""} @@ -3024,6 +3025,45 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] text-helpers = ["chardet (>=5.1.0,<6.0.0)"] +[[package]] +name = "langchain-core" +version = "0.1.12" +description = "Building applications with LLMs through composability" +optional = true +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_core-0.1.12-py3-none-any.whl", hash = "sha256:d11c6262f7a9deff7de8fdf14498b8a951020dfed3a80f2358ab731ad04abef0"}, + {file = "langchain_core-0.1.12.tar.gz", hash = "sha256:f18e9300e9a07589b3e280e51befbc5a4513f535949406e55eb7a2dc40c3ce66"}, +] + +[package.dependencies] +anyio = ">=3,<5" +jsonpatch = ">=1.33,<2.0" +langsmith = ">=0.0.63,<0.1.0" +packaging = ">=23.2,<24.0" +pydantic = ">=1,<3" +PyYAML = ">=5.3" +requests = ">=2,<3" +tenacity = ">=8.1.0,<9.0.0" + +[package.extras] +extended-testing = ["jinja2 (>=3,<4)"] + +[[package]] +name = "langchain-mistralai" +version = "0.0.3" +description = "An integration package connecting Mistral and LangChain" +optional = true +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_mistralai-0.0.3-py3-none-any.whl", hash = "sha256:ebb8ba3d7978b5ee16f7e09512ffa434e00bc9863f1537f1a5f5203882d99619"}, + {file = "langchain_mistralai-0.0.3.tar.gz", hash = "sha256:2e45ee0118df8e4b5577ce8c4f89743059801e473f40a8b7c89cb99dd715f423"}, +] + +[package.dependencies] +langchain-core = ">=0.1,<0.2" +mistralai = ">=0.0.11,<0.0.12" + [[package]] name = "langdetect" version = "1.0.9" @@ -3458,6 +3498,22 @@ files = [ certifi = "*" urllib3 = "*" +[[package]] +name = "mistralai" +version = "0.0.11" +description = "" +optional = true +python-versions = ">=3.8,<4.0" +files = [ + {file = "mistralai-0.0.11-py3-none-any.whl", hash = "sha256:fb2a240a3985420c4e7db48eb5077d6d6dbc5e83cac0dd948c20342fb48087ee"}, + {file = "mistralai-0.0.11.tar.gz", hash = "sha256:383072715531198305dab829ab3749b64933bbc2549354f3c9ebc43c17b912cf"}, +] + +[package.dependencies] +httpx = ">=0.25.2,<0.26.0" +orjson = ">=3.9.10,<4.0.0" +pydantic = ">=2.5.2,<3.0.0" + [[package]] name = "mock" version = "5.1.0" @@ -4294,6 +4350,65 @@ files = [ {file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"}, ] +[[package]] +name = "orjson" +version = "3.9.12" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +optional = true +python-versions = ">=3.8" +files = [ + {file = "orjson-3.9.12-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6b4e2bed7d00753c438e83b613923afdd067564ff7ed696bfe3a7b073a236e07"}, + {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd1b8ec63f0bf54a50b498eedeccdca23bd7b658f81c524d18e410c203189365"}, + {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ab8add018a53665042a5ae68200f1ad14c7953fa12110d12d41166f111724656"}, + {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12756a108875526b76e505afe6d6ba34960ac6b8c5ec2f35faf73ef161e97e07"}, + {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:890e7519c0c70296253660455f77e3a194554a3c45e42aa193cdebc76a02d82b"}, + {file = "orjson-3.9.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d664880d7f016efbae97c725b243b33c2cbb4851ddc77f683fd1eec4a7894146"}, + {file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cfdaede0fa5b500314ec7b1249c7e30e871504a57004acd116be6acdda3b8ab3"}, + {file = "orjson-3.9.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6492ff5953011e1ba9ed1bf086835fd574bd0a3cbe252db8e15ed72a30479081"}, + {file = "orjson-3.9.12-cp310-none-win32.whl", hash = "sha256:29bf08e2eadb2c480fdc2e2daae58f2f013dff5d3b506edd1e02963b9ce9f8a9"}, + {file = "orjson-3.9.12-cp310-none-win_amd64.whl", hash = "sha256:0fc156fba60d6b50743337ba09f052d8afc8b64595112996d22f5fce01ab57da"}, + {file = "orjson-3.9.12-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2849f88a0a12b8d94579b67486cbd8f3a49e36a4cb3d3f0ab352c596078c730c"}, + {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3186b18754befa660b31c649a108a915493ea69b4fc33f624ed854ad3563ac65"}, + {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbbf313c9fb9d4f6cf9c22ced4b6682230457741daeb3d7060c5d06c2e73884a"}, + {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99e8cd005b3926c3db9b63d264bd05e1bf4451787cc79a048f27f5190a9a0311"}, + {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59feb148392d9155f3bfed0a2a3209268e000c2c3c834fb8fe1a6af9392efcbf"}, + {file = "orjson-3.9.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4ae815a172a1f073b05b9e04273e3b23e608a0858c4e76f606d2d75fcabde0c"}, + {file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ed398f9a9d5a1bf55b6e362ffc80ac846af2122d14a8243a1e6510a4eabcb71e"}, + {file = "orjson-3.9.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d3cfb76600c5a1e6be91326b8f3b83035a370e727854a96d801c1ea08b708073"}, + {file = "orjson-3.9.12-cp311-none-win32.whl", hash = "sha256:a2b6f5252c92bcab3b742ddb3ac195c0fa74bed4319acd74f5d54d79ef4715dc"}, + {file = "orjson-3.9.12-cp311-none-win_amd64.whl", hash = "sha256:c95488e4aa1d078ff5776b58f66bd29d628fa59adcb2047f4efd3ecb2bd41a71"}, + {file = "orjson-3.9.12-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d6ce2062c4af43b92b0221ed4f445632c6bf4213f8a7da5396a122931377acd9"}, + {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:950951799967558c214cd6cceb7ceceed6f81d2c3c4135ee4a2c9c69f58aa225"}, + {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2dfaf71499d6fd4153f5c86eebb68e3ec1bf95851b030a4b55c7637a37bbdee4"}, + {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:659a8d7279e46c97661839035a1a218b61957316bf0202674e944ac5cfe7ed83"}, + {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af17fa87bccad0b7f6fd8ac8f9cbc9ee656b4552783b10b97a071337616db3e4"}, + {file = "orjson-3.9.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd52dec9eddf4c8c74392f3fd52fa137b5f2e2bed1d9ae958d879de5f7d7cded"}, + {file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:640e2b5d8e36b970202cfd0799d11a9a4ab46cf9212332cd642101ec952df7c8"}, + {file = "orjson-3.9.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:daa438bd8024e03bcea2c5a92cd719a663a58e223fba967296b6ab9992259dbf"}, + {file = "orjson-3.9.12-cp312-none-win_amd64.whl", hash = "sha256:1bb8f657c39ecdb924d02e809f992c9aafeb1ad70127d53fb573a6a6ab59d549"}, + {file = "orjson-3.9.12-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f4098c7674901402c86ba6045a551a2ee345f9f7ed54eeffc7d86d155c8427e5"}, + {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5586a533998267458fad3a457d6f3cdbddbcce696c916599fa8e2a10a89b24d3"}, + {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:54071b7398cd3f90e4bb61df46705ee96cb5e33e53fc0b2f47dbd9b000e238e1"}, + {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67426651faa671b40443ea6f03065f9c8e22272b62fa23238b3efdacd301df31"}, + {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a0cd56e8ee56b203abae7d482ac0d233dbfb436bb2e2d5cbcb539fe1200a312"}, + {file = "orjson-3.9.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84a0c3d4841a42e2571b1c1ead20a83e2792644c5827a606c50fc8af7ca4bee"}, + {file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:09d60450cda3fa6c8ed17770c3a88473a16460cd0ff2ba74ef0df663b6fd3bb8"}, + {file = "orjson-3.9.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bc82a4db9934a78ade211cf2e07161e4f068a461c1796465d10069cb50b32a80"}, + {file = "orjson-3.9.12-cp38-none-win32.whl", hash = "sha256:61563d5d3b0019804d782137a4f32c72dc44c84e7d078b89d2d2a1adbaa47b52"}, + {file = "orjson-3.9.12-cp38-none-win_amd64.whl", hash = "sha256:410f24309fbbaa2fab776e3212a81b96a1ec6037259359a32ea79fbccfcf76aa"}, + {file = "orjson-3.9.12-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e773f251258dd82795fd5daeac081d00b97bacf1548e44e71245543374874bcf"}, + {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b159baecfda51c840a619948c25817d37733a4d9877fea96590ef8606468b362"}, + {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:975e72e81a249174840d5a8df977d067b0183ef1560a32998be340f7e195c730"}, + {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:06e42e899dde61eb1851a9fad7f1a21b8e4be063438399b63c07839b57668f6c"}, + {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c157e999e5694475a5515942aebeed6e43f7a1ed52267c1c93dcfde7d78d421"}, + {file = "orjson-3.9.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dde1bc7c035f2d03aa49dc8642d9c6c9b1a81f2470e02055e76ed8853cfae0c3"}, + {file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b0e9d73cdbdad76a53a48f563447e0e1ce34bcecef4614eb4b146383e6e7d8c9"}, + {file = "orjson-3.9.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:96e44b21fe407b8ed48afbb3721f3c8c8ce17e345fbe232bd4651ace7317782d"}, + {file = "orjson-3.9.12-cp39-none-win32.whl", hash = "sha256:cbd0f3555205bf2a60f8812133f2452d498dbefa14423ba90fe89f32276f7abf"}, + {file = "orjson-3.9.12-cp39-none-win_amd64.whl", hash = "sha256:03ea7ee7e992532c2f4a06edd7ee1553f0644790553a118e003e3c405add41fa"}, + {file = "orjson-3.9.12.tar.gz", hash = "sha256:da908d23a3b3243632b523344403b128722a5f45e278a8343c2bb67538dff0e4"}, +] + [[package]] name = "overrides" version = "7.4.0" @@ -8110,6 +8225,7 @@ googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth- huggingface-hub = ["huggingface_hub"] llama2 = ["replicate"] milvus = ["pymilvus"] +mistralai = ["langchain-mistralai"] modal = ["modal"] mysql = ["mysql-connector-python"] opensearch = ["opensearch-py"] @@ -8130,4 +8246,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "02bd85e14374a9dc9b59523b8fb4baea7068251976ba7f87722cac94a9974ccc" +content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c" diff --git a/pyproject.toml b/pyproject.toml index 01e9bb5d..b7d9059b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true } google-auth = { version = "^2.25.2", optional = true } google-auth-httplib2 = { version = "^0.2.0", optional = true } google-api-core = { version = "^2.15.0", optional = true } +langchain-mistralai = { version = "^0.0.3", optional = true } [tool.poetry.group.dev.dependencies] black = "^23.3.0" @@ -214,6 +215,7 @@ rss_feed = [ google = ["google-generativeai"] modal = ["modal"] dropbox = ["dropbox"] +mistralai = ["langchain-mistralai"] [tool.poetry.group.docs.dependencies] diff --git a/tests/llm/test_mistralai.py b/tests/llm/test_mistralai.py new file mode 100644 index 00000000..c99999c8 --- /dev/null +++ b/tests/llm/test_mistralai.py @@ -0,0 +1,60 @@ +import pytest + +from embedchain.config import BaseLlmConfig +from embedchain.llm.mistralai import MistralAILlm + + +@pytest.fixture +def mistralai_llm_config(monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key") + yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False) + monkeypatch.delenv("MISTRAL_API_KEY", raising=False) + + +def test_mistralai_llm_init_missing_api_key(monkeypatch): + monkeypatch.delenv("MISTRAL_API_KEY", raising=False) + with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."): + MistralAILlm() + + +def test_mistralai_llm_init(monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key") + llm = MistralAILlm() + assert llm is not None + + +def test_get_llm_model_answer(monkeypatch, mistralai_llm_config): + def mock_get_answer(prompt, config): + return "Generated Text" + + monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer) + llm = MistralAILlm(config=mistralai_llm_config) + result = llm.get_llm_model_answer("test prompt") + + assert result == "Generated Text" + + +def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config): + mistralai_llm_config.system_prompt = "Test system prompt" + monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text") + llm = MistralAILlm(config=mistralai_llm_config) + result = llm.get_llm_model_answer("test prompt") + + assert result == "Generated Text" + + +def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config): + monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text") + llm = MistralAILlm(config=mistralai_llm_config) + result = llm.get_llm_model_answer("") + + assert result == "Generated Text" + + +def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config): + mistralai_llm_config.system_prompt = None + monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text") + llm = MistralAILlm(config=mistralai_llm_config) + result = llm.get_llm_model_answer("test prompt") + + assert result == "Generated Text"