# Listing metadata: 110 lines, 5.2 KiB, Python
import importlib
|
|
|
|
|
|
def load_class(class_type):
    """Import and return the object named by a dotted path.

    Args:
        class_type: Dotted path such as ``"package.module.ClassName"``. Everything
            before the last ``"."`` is imported as a module; the remainder is
            looked up as an attribute on that module.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``class_type`` has no module component.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_path, separator, class_name = class_type.rpartition(".")
    # Fail with a message naming the bad path instead of an opaque unpacking error.
    if not separator or not module_path:
        raise ValueError(
            f"Expected a dotted path of the form 'module.ClassName', got: {class_type!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
|
|
|
|
|
|
class LlmFactory:
    """Factory that builds LLM objects from a provider name and a config mapping.

    Classes are referenced by dotted path and resolved through ``load_class``
    only when ``create`` is called for that provider.
    """

    # Provider name -> dotted path of the LLM class.
    provider_to_class = {
        "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
        "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
        "cohere": "embedchain.llm.cohere.CohereLlm",
        "together": "embedchain.llm.together.TogetherLlm",
        "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
        "ollama": "embedchain.llm.ollama.OllamaLlm",
        "huggingface": "embedchain.llm.huggingface.HuggingFaceLlm",
        "jina": "embedchain.llm.jina.JinaLlm",
        "llama2": "embedchain.llm.llama2.Llama2Llm",
        "openai": "embedchain.llm.openai.OpenAILlm",
        "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
        "google": "embedchain.llm.google.GoogleLlm",
        "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
        "mistralai": "embedchain.llm.mistralai.MistralAILlm",
        "groq": "embedchain.llm.groq.GroqLlm",
    }

    # Provider name -> dotted path of the config class. Providers not listed
    # here use the "embedchain" entry.
    provider_to_config_class = {
        "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
        "openai": "embedchain.config.llm.base.BaseLlmConfig",
        "anthropic": "embedchain.config.llm.base.BaseLlmConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """Instantiate the LLM registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class``.
            config_data: Mapping of keyword arguments passed to the config class.
                ``None`` is treated as an empty mapping.

        Returns:
            An instance of the provider's LLM class, constructed with
            ``config=<config class instance>``.

        Raises:
            ValueError: If ``provider_name`` is not a supported provider.
        """
        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        # Default to embedchain base config if the provider is not in the config map
        config_name = provider_name if provider_name in cls.provider_to_config_class else "embedchain"
        config_class_type = cls.provider_to_config_class.get(config_name)

        # A missing config section arrives as None; ``**None`` would raise TypeError.
        if config_data is None:
            config_data = {}

        llm_class = load_class(class_type)
        llm_config_class = load_class(config_class_type)
        return llm_class(config=llm_config_class(**config_data))
|
|
|
|
|
|
class EmbedderFactory:
    """Factory that builds embedder objects from a provider name and a config mapping.

    Classes are referenced by dotted path and resolved through ``load_class``
    only when ``create`` is called for that provider.
    """

    # Provider name -> dotted path of the embedder class.
    # Note: "azure_openai" and "openai" map to the same embedder class.
    provider_to_class = {
        "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
        "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
        "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
        "openai": "embedchain.embedder.openai.OpenAIEmbedder",
        "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
        "google": "embedchain.embedder.google.GoogleAIEmbedder",
        "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
    }

    # Provider name -> dotted path of the config class. Providers not listed
    # here use the "openai" entry.
    provider_to_config_class = {
        "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
        "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
        "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
        "google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
        "huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """Instantiate the embedder registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class``.
            config_data: Mapping of keyword arguments passed to the config class.
                ``None`` is treated as an empty mapping.

        Returns:
            An instance of the provider's embedder class, constructed with
            ``config=<config class instance>``.

        Raises:
            ValueError: If ``provider_name`` is not a supported provider.
        """
        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")

        # Default to openai config if the provider is not in the config map
        config_name = provider_name if provider_name in cls.provider_to_config_class else "openai"
        config_class_type = cls.provider_to_config_class.get(config_name)

        # A missing config section arrives as None; ``**None`` would raise TypeError.
        if config_data is None:
            config_data = {}

        embedder_class = load_class(class_type)
        embedder_config_class = load_class(config_class_type)
        return embedder_class(config=embedder_config_class(**config_data))
|
|
|
|
|
|
class VectorDBFactory:
    """Factory that builds vector database objects from a provider name and a config mapping.

    Classes are referenced by dotted path and resolved through ``load_class``
    only when ``create`` is called for that provider.
    """

    # Provider name -> dotted path of the vector database class.
    provider_to_class = {
        "chroma": "embedchain.vectordb.chroma.ChromaDB",
        "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
        "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
        "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
        "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
        "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
        "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
    }

    # Provider name -> dotted path of the config class. Unlike the other
    # factories there is no fallback entry: each provider has its own config class.
    provider_to_config_class = {
        "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
        "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
        "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
        "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
        "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
        "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
        "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """Instantiate the vector database registered for ``provider_name``.

        Args:
            provider_name: Key into ``provider_to_class`` and
                ``provider_to_config_class``.
            config_data: Mapping of keyword arguments passed to the config class.
                ``None`` is treated as an empty mapping.

        Returns:
            An instance of the provider's vector database class, constructed
            with ``config=<config class instance>``.

        Raises:
            ValueError: If ``provider_name`` is not a supported provider.
        """
        class_type = cls.provider_to_class.get(provider_name)
        if not class_type:
            # The message names the vector DB; it previously said "Embedder".
            raise ValueError(f"Unsupported VectorDB provider: {provider_name}")

        config_class_type = cls.provider_to_config_class.get(provider_name)

        # A missing config section arrives as None; ``**None`` would raise TypeError.
        if config_data is None:
            config_data = {}

        vectordb_class = load_class(class_type)
        vectordb_config_class = load_class(config_class_type)
        return vectordb_class(config=vectordb_config_class(**config_data))
|