[Feature] Add support for Mistral API (#1194)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Author: Deven Patel
Date: 2024-01-20 12:31:50 +05:30
Committed by: GitHub
Parent commit: 9afc6878c8
Commit: cb0499407e
9 changed files with 351 additions and 12 deletions

@@ -0,0 +1,46 @@
import os
from typing import Optional, Union

from chromadb import EmbeddingFunction, Embeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class MistralAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: BaseEmbedderConfig) -> None:
        super().__init__()
        try:
            from langchain_mistralai import MistralAIEmbeddings
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None
        self.config = config
        # Prefer an explicitly configured key; fall back to the environment variable.
        api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
        self.client = MistralAIEmbeddings(mistral_api_key=api_key)
        self.client.model = self.config.model

    def __call__(self, input: Union[list[str], str]) -> Embeddings:
        # Chroma may hand us a single string; normalize to a list of documents.
        input_ = [input] if isinstance(input, str) else input
        return self.client.embed_documents(input_)


class MistralAIEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)
        if self.config.model is None:
            self.config.model = "mistral-embed"
        embedding_fn = MistralAIEmbeddingFunction(config=self.config)
        self.set_embedding_fn(embedding_fn=embedding_fn)
        # mistral-embed returns 1024-dimensional vectors (VectorDimensions.MISTRAL_AI).
        vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
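
A quick usage sketch for the new embedder (the `api_key` and `model` fields of `BaseEmbedderConfig` are the ones read above; everything else here is illustrative, not part of this diff):

# Hypothetical sketch, not part of this diff.
import os

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.mistralai import MistralAIEmbeddingFunction

os.environ["MISTRAL_API_KEY"] = "your-api-key"  # or pass api_key=... in the config

fn = MistralAIEmbeddingFunction(config=BaseEmbedderConfig(model="mistral-embed"))
vectors = fn(["Hello, world!"])  # one 1024-dimensional vector per input string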

@@ -21,6 +21,7 @@ class LlmFactory:
"openai": "embedchain.llm.openai.OpenAILlm",
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
"google": "embedchain.llm.google.GoogleLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
}
provider_to_config_class = {
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
@@ -50,6 +51,7 @@ class EmbedderFactory:
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",

@@ -0,0 +1,52 @@
import os
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class MistralAILlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)
        if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
            raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")

    def get_llm_model_answer(self, prompt):
        return MistralAILlm._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        try:
            from langchain_core.messages import HumanMessage, SystemMessage
            from langchain_mistralai.chat_models import ChatMistralAI
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None
        api_key = config.api_key or os.getenv("MISTRAL_API_KEY")
        client = ChatMistralAI(mistral_api_key=api_key)
        messages = []
        if config.system_prompt:
            messages.append(SystemMessage(content=config.system_prompt))
        messages.append(HumanMessage(content=prompt))
        kwargs = {
            "model": config.model or "mistral-tiny",
            "temperature": config.temperature,
            "max_tokens": config.max_tokens,
            "top_p": config.top_p,
        }
        if config.stream:
            # TODO: surface the stream to the caller; for now the chunks are
            # accumulated and returned as a single string.
            answer = ""
            for chunk in client.stream(**kwargs, input=messages):
                answer += chunk.content
            return answer
        else:
            response = client.invoke(**kwargs, input=messages)
            return response.content
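
A usage sketch for the new LLM class; `model`, `temperature`, `top_p`, `stream`, and `system_prompt` are the `BaseLlmConfig` fields read by `_get_answer` above, while passing them as constructor kwargs is an assumption for illustration:

# Hypothetical sketch, not part of this diff.
import os

from embedchain.config import BaseLlmConfig
from embedchain.llm.mistralai import MistralAILlm

os.environ["MISTRAL_API_KEY"] = "your-api-key"

llm = MistralAILlm(config=BaseLlmConfig(model="mistral-tiny", temperature=0.5, stream=False))
print(llm.get_llm_model_answer("What is the capital of France?"))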

@@ -8,3 +8,4 @@ class VectorDimensions(Enum):
    VERTEX_AI = 768
    HUGGING_FACE = 384
    GOOGLE_AI = 768
+   MISTRAL_AI = 1024

@@ -406,6 +406,7 @@ def validate_config(config_data):
"llama2",
"vertexai",
"google",
"mistralai",
),
Optional("config"): {
Optional("model"): str,
@@ -431,7 +432,15 @@ def validate_config(config_data):
Optional("config"): object, # TODO: add particular config schema for each provider
},
Optional("embedder"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("provider"): Or(
"openai",
"gpt4all",
"huggingface",
"vertexai",
"azure_openai",
"google",
"mistralai",
),
Optional("config"): {
Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str),
@@ -442,7 +451,15 @@ def validate_config(config_data):
        },
    },
    Optional("embedding_model"): {
-       Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
+       Optional("provider"): Or(
+           "openai",
+           "gpt4all",
+           "huggingface",
+           "vertexai",
+           "azure_openai",
+           "google",
+           "mistralai",
+       ),
        Optional("config"): {
            Optional("model"): str,
            Optional("deployment_name"): str,