[Feature] Add support for Mistral API (#1194)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
46
embedchain/embedder/mistralai.py
Normal file
46
embedchain/embedder/mistralai.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.models import VectorDimensions
|
||||
|
||||
|
||||
class MistralAIEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, config: BaseEmbedderConfig) -> None:
|
||||
super().__init__()
|
||||
try:
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The required dependencies for MistralAI are not installed."
|
||||
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
|
||||
) from None
|
||||
self.config = config
|
||||
api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
|
||||
self.client = MistralAIEmbeddings(mistral_api_key=api_key)
|
||||
self.client.model = self.config.model
|
||||
|
||||
def __call__(self, input: Union[list[str], str]) -> Embeddings:
|
||||
if isinstance(input, str):
|
||||
input_ = [input]
|
||||
else:
|
||||
input_ = input
|
||||
response = self.client.embed_documents(input_)
|
||||
return response
|
||||
|
||||
|
||||
class MistralAIEmbedder(BaseEmbedder):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config)
|
||||
|
||||
if self.config.model is None:
|
||||
self.config.model = "mistral-embed"
|
||||
|
||||
embedding_fn = MistralAIEmbeddingFunction(config=self.config)
|
||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||
|
||||
vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
|
||||
self.set_vector_dimension(vector_dimension=vector_dimension)
|
||||
@@ -21,6 +21,7 @@ class LlmFactory:
|
||||
"openai": "embedchain.llm.openai.OpenAILlm",
|
||||
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
|
||||
"google": "embedchain.llm.google.GoogleLlm",
|
||||
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
||||
@@ -50,6 +51,7 @@ class EmbedderFactory:
|
||||
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
||||
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
|
||||
"google": "embedchain.embedder.google.GoogleAIEmbedder",
|
||||
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
|
||||
52
embedchain/llm/mistralai.py
Normal file
52
embedchain/llm/mistralai.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class MistralAILlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config)
|
||||
if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
return MistralAILlm._get_answer(prompt=prompt, config=self.config)
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig):
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The required dependencies for MistralAI are not installed."
|
||||
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
|
||||
) from None
|
||||
|
||||
api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
|
||||
client = ChatMistralAI(mistral_api_key=api_key)
|
||||
messages = []
|
||||
if config.system_prompt:
|
||||
messages.append(SystemMessage(content=config.system_prompt))
|
||||
messages.append(HumanMessage(content=prompt))
|
||||
kwargs = {
|
||||
"model": config.model or "mistral-tiny",
|
||||
"temperature": config.temperature,
|
||||
"max_tokens": config.max_tokens,
|
||||
"top_p": config.top_p,
|
||||
}
|
||||
|
||||
# TODO: Add support for streaming
|
||||
if config.stream:
|
||||
answer = ""
|
||||
for chunk in client.stream(**kwargs, input=messages):
|
||||
answer += chunk.content
|
||||
return answer
|
||||
else:
|
||||
response = client.invoke(**kwargs, input=messages)
|
||||
answer = response.content
|
||||
return answer
|
||||
@@ -8,3 +8,4 @@ class VectorDimensions(Enum):
|
||||
VERTEX_AI = 768
|
||||
HUGGING_FACE = 384
|
||||
GOOGLE_AI = 768
|
||||
MISTRAL_AI = 1024
|
||||
|
||||
@@ -406,6 +406,7 @@ def validate_config(config_data):
|
||||
"llama2",
|
||||
"vertexai",
|
||||
"google",
|
||||
"mistralai",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
@@ -431,7 +432,15 @@ def validate_config(config_data):
|
||||
Optional("config"): object, # TODO: add particular config schema for each provider
|
||||
},
|
||||
Optional("embedder"): {
|
||||
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
|
||||
Optional("provider"): Or(
|
||||
"openai",
|
||||
"gpt4all",
|
||||
"huggingface",
|
||||
"vertexai",
|
||||
"azure_openai",
|
||||
"google",
|
||||
"mistralai",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): Optional(str),
|
||||
Optional("deployment_name"): Optional(str),
|
||||
@@ -442,7 +451,15 @@ def validate_config(config_data):
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
|
||||
Optional("provider"): Or(
|
||||
"openai",
|
||||
"gpt4all",
|
||||
"huggingface",
|
||||
"vertexai",
|
||||
"azure_openai",
|
||||
"google",
|
||||
"mistralai",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
Optional("deployment_name"): str,
|
||||
|
||||
Reference in New Issue
Block a user