Clarifai: Added Clarifai as LLM and embedding model provider. (#1311)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
Author: mogith-pn
Date: 2024-06-17 21:18:18 +05:30 (committed by GitHub)
Parent: 4547d870af
Commit: 5acaae5f56
12 changed files with 579 additions and 2 deletions


@@ -0,0 +1,52 @@
import logging
import os
from typing import Optional, Union

from chromadb import EmbeddingFunction, Embeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder


class ClarifaiEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: BaseEmbedderConfig) -> None:
        super().__init__()
        try:
            from clarifai.client.input import Inputs
            from clarifai.client.model import Model
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for ClarifaiEmbeddingFunction are not installed. "
                'Please install with `pip install --upgrade "embedchain[clarifai]"`'
            ) from None
        self.config = config
        self.api_key = config.api_key or os.getenv("CLARIFAI_PAT")
        self.model = config.model
        self.model_obj = Model(url=self.model, pat=self.api_key)
        self.input_obj = Inputs(pat=self.api_key)

    def __call__(self, input: Union[str, list[str]]) -> Embeddings:
        if isinstance(input, str):
            input = [input]

        # Embed in batches of 32 to keep each Clarifai predict call within request limits.
        batch_size = 32
        embeddings = []
        try:
            for i in range(0, len(input), batch_size):
                batch = input[i : i + batch_size]
                input_batch = [
                    self.input_obj.get_text_input(input_id=str(idx), raw_text=text)
                    for idx, text in enumerate(batch)
                ]
                response = self.model_obj.predict(input_batch)
                embeddings.extend([list(output.data.embeddings[0].vector) for output in response.outputs])
        except Exception as e:
            logging.error(f"Predict failed, exception: {e}")
        return embeddings


class ClarifaiEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)
        embedding_func = ClarifaiEmbeddingFunction(config=self.config)
        self.set_embedding_fn(embedding_fn=embedding_func)
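
For orientation, a minimal sketch of calling the new embedding function directly. The model URL is an illustrative placeholder (any Clarifai-hosted text-embedding model URL should work), and CLARIFAI_PAT is assumed to be set in the environment:

import os

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.clarifai import ClarifaiEmbeddingFunction

# Placeholder Clarifai model URL, not one mandated by this commit.
config = BaseEmbedderConfig(
    model="https://clarifai.com/clarifai/main/models/BAAI-bge-base-en-v15",
    api_key=os.getenv("CLARIFAI_PAT"),
)
embed = ClarifaiEmbeddingFunction(config=config)
vectors = embed(["hello world", "goodbye world"])  # one embedding vector per input text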


@@ -23,6 +23,7 @@ class LlmFactory:
"google": "embedchain.llm.google.GoogleLlm",
"aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
"clarifai": "embedchain.llm.clarifai.ClarifaiLlm",
"groq": "embedchain.llm.groq.GroqLlm",
"nvidia": "embedchain.llm.nvidia.NvidiaLlm",
"vllm": "embedchain.llm.vllm.VLLM",
@@ -56,6 +57,7 @@ class EmbedderFactory:
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
"clarifai": "embedchain.embedder.clarifai.ClarifaiEmbedder",
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
@@ -65,6 +67,7 @@ class EmbedderFactory:
"google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
"clarifai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
}
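
With these three registrations, both factories can resolve the "clarifai" provider name to the new classes. A hedged sketch, assuming the create(provider_name, config_data) classmethods the factories use for the other providers, with placeholder model URLs:

from embedchain.factory import EmbedderFactory, LlmFactory

# "clarifai" now maps to ClarifaiLlm / ClarifaiEmbedder via the import paths above.
llm = LlmFactory.create(
    "clarifai",
    {"model": "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct"},
)
embedder = EmbedderFactory.create(
    "clarifai",
    {"model": "https://clarifai.com/clarifai/main/models/BAAI-bge-base-en-v15"},
)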


@@ -0,0 +1,47 @@
import logging
import os
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class ClarifaiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        if not self.config.api_key and "CLARIFAI_PAT" not in os.environ:
            raise ValueError("Please set the CLARIFAI_PAT environment variable.")

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
        try:
            from clarifai.client.model import Model
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for Clarifai are not installed. "
                'Please install with `pip install --upgrade "embedchain[clarifai]"`'
            ) from None

        model_name = config.model
        logging.info(f"Using Clarifai LLM model: {model_name}")
        api_key = config.api_key or os.getenv("CLARIFAI_PAT")
        model = Model(url=model_name, pat=api_key)
        params = config.model_kwargs or {}
        try:
            predict_response = model.predict_by_bytes(
                bytes(prompt, "utf-8"),
                input_type="text",
                inference_params=params,
            )
            return predict_response.outputs[0].data.text.raw
        except Exception as e:
            # Log and fall through (returning None) if the prediction call fails.
            logging.error(f"Predict failed, exception: {e}")
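
A minimal usage sketch of the class above. The module path embedchain.llm.clarifai matches the factory registration, the model URL is a placeholder, and CLARIFAI_PAT is assumed to be exported:

from embedchain.config import BaseLlmConfig
from embedchain.llm.clarifai import ClarifaiLlm

# model_kwargs are forwarded to Clarifai as inference_params.
llm = ClarifaiLlm(
    config=BaseLlmConfig(
        model="https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
        model_kwargs={"temperature": 0.7, "max_tokens": 512},
    )
)
print(llm.get_llm_model_answer("In one sentence, what does Clarifai do?"))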


@@ -414,6 +414,7 @@ def validate_config(config_data):
"google",
"aws_bedrock",
"mistralai",
"clarifai",
"vllm",
"groq",
"nvidia",
@@ -458,6 +459,7 @@ def validate_config(config_data):
"azure_openai",
"google",
"mistralai",
"clarifai",
"nvidia",
"ollama",
"cohere",
@@ -482,6 +484,7 @@ def validate_config(config_data):
"azure_openai",
"google",
"mistralai",
"clarifai",
"nvidia",
"ollama",
),
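
Taken together, the schema now accepts "clarifai" for both the llm and embedder providers, so an end-to-end config along these lines should validate. A sketch with placeholder model URLs, assuming embedchain's usual App.from_config entry point:

import os

from embedchain import App

os.environ["CLARIFAI_PAT"] = "your-pat-here"  # assumption: auth via the env var checked above

app = App.from_config(
    config={
        "llm": {
            "provider": "clarifai",
            "config": {
                "model": "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
                "model_kwargs": {"temperature": 0.5, "max_tokens": 1000},
            },
        },
        "embedder": {
            "provider": "clarifai",
            "config": {
                "model": "https://clarifai.com/clarifai/main/models/BAAI-bge-base-en-v15",
            },
        },
    }
)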


@@ -251,4 +251,4 @@ class QdrantDB(BaseVectorDB):
    def delete(self, where: dict):
        db_filter = self._generate_query(where)
        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)