From edaeb78ccf1bf046b83b0c6f8d4cc9973b061fac Mon Sep 17 00:00:00 2001 From: Vatsal Rathod Date: Wed, 26 Jun 2024 13:58:12 -0400 Subject: [PATCH] Refactor openai embedder (#1444) --- embedchain/embedder/azure_openai.py | 22 ++++++++++++++++++++++ embedchain/embedder/openai.py | 24 +++++++++--------------- embedchain/factory.py | 2 +- 3 files changed, 32 insertions(+), 16 deletions(-) create mode 100644 embedchain/embedder/azure_openai.py diff --git a/embedchain/embedder/azure_openai.py b/embedchain/embedder/azure_openai.py new file mode 100644 index 00000000..97441f84 --- /dev/null +++ b/embedchain/embedder/azure_openai.py @@ -0,0 +1,22 @@ +from typing import Optional + +from langchain_community.embeddings import AzureOpenAIEmbeddings + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base import BaseEmbedder +from embedchain.models import VectorDimensions + + +class AzureOpenAIEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config=config) + + if self.config.model is None: + self.config.model = "text-embedding-ada-002" + + embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + + self.set_embedding_fn(embedding_fn=embedding_fn) + vector_dimension = self.config.vector_dimension or VectorDimensions.OPENAI.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/embedder/openai.py b/embedchain/embedder/openai.py index ab361e27..fc2c7d63 100644 --- a/embedchain/embedder/openai.py +++ b/embedchain/embedder/openai.py @@ -2,7 +2,7 @@ import os from typing import Optional from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction -from langchain_openai.embeddings import AzureOpenAIEmbeddings + from embedchain.config import BaseEmbedderConfig from embedchain.embedder.base import BaseEmbedder @@ -19,20 +19,14 @@ class OpenAIEmbedder(BaseEmbedder): api_key = self.config.api_key or os.environ["OPENAI_API_KEY"] api_base = self.config.api_base or os.environ.get("OPENAI_API_BASE") - if self.config.deployment_name: - embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name) - embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) - else: - if api_key is None and os.getenv("OPENAI_ORGANIZATION") is None: - raise ValueError( - "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided" - ) # noqa:E501 - embedding_fn = OpenAIEmbeddingFunction( - api_key=api_key, - api_base=api_base, - organization_id=os.getenv("OPENAI_ORGANIZATION"), - model_name=self.config.model, - ) + if api_key is None and os.getenv("OPENAI_ORGANIZATION") is None: + raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501 + embedding_fn = OpenAIEmbeddingFunction( + api_key=api_key, + api_base=api_base, + organization_id=os.getenv("OPENAI_ORGANIZATION"), + model_name=self.config.model, + ) self.set_embedding_fn(embedding_fn=embedding_fn) vector_dimension = self.config.vector_dimension or VectorDimensions.OPENAI.value self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/factory.py b/embedchain/factory.py index 0ed7452e..db07e377 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -50,7 +50,7 @@ class LlmFactory: class EmbedderFactory: provider_to_class = { - "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder", + "azure_openai": "embedchain.embedder.azure_openai.AzureOpenAIEmbedder", "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder", "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder", "openai": "embedchain.embedder.openai.OpenAIEmbedder",