diff --git a/embedchain/apps/CustomApp.py b/embedchain/apps/CustomApp.py index e0ef9fc1..6b3cf30e 100644 --- a/embedchain/apps/CustomApp.py +++ b/embedchain/apps/CustomApp.py @@ -118,9 +118,13 @@ class CustomApp(EmbedChain): def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str: from langchain.chat_models import AzureChatOpenAI + if not config.deployment_name: + raise ValueError("Deployment name must be provided for Azure OpenAI") + chat = AzureChatOpenAI( - deployment_name="td2", - model_name=config.model or "text-davinci-002", + deployment_name=config.deployment_name, + openai_api_version="2023-05-15", + model_name=config.model or "gpt-3.5-turbo", temperature=config.temperature, max_tokens=config.max_tokens, streaming=config.stream, diff --git a/embedchain/config/ChatConfig.py b/embedchain/config/ChatConfig.py index 1f3a614d..0149fd3a 100644 --- a/embedchain/config/ChatConfig.py +++ b/embedchain/config/ChatConfig.py @@ -33,6 +33,7 @@ class ChatConfig(QueryConfig): max_tokens=None, top_p=None, stream: bool = False, + deployment_name=None, ): """ Initializes the ChatConfig instance. @@ -68,6 +69,7 @@ class ChatConfig(QueryConfig): top_p=top_p, history=[0], stream=stream, + deployment_name=deployment_name, ) def set_history(self, history): diff --git a/embedchain/config/QueryConfig.py b/embedchain/config/QueryConfig.py index b339b188..ef06f597 100644 --- a/embedchain/config/QueryConfig.py +++ b/embedchain/config/QueryConfig.py @@ -62,6 +62,7 @@ class QueryConfig(BaseConfig): top_p=None, history=None, stream: bool = False, + deployment_name=None, ): """ Initializes the QueryConfig instance. @@ -106,6 +107,7 @@ class QueryConfig(BaseConfig): self.max_tokens = max_tokens if max_tokens else 1000 self.model = model self.top_p = top_p if top_p else 1 + self.deployment_name = deployment_name if self.validate_template(template): self.template = template diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index 5649721e..22f8818e 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig): provider: Providers = None, model=None, open_source_app_config=None, + deployment_name=None, ): """ :param log_level: Optional. (String) Debug level @@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig): super().__init__( log_level=log_level, - embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model), + embedding_fn=CustomAppConfig.embedding_function( + embedding_function=embedding_fn, model=embedding_fn_model, + deployment_name=deployment_name + ), db=db, host=host, port=port, @@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig): return embed_function @staticmethod - def embedding_function(embedding_function: EmbeddingFunctions, model: str = None): + def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None): if not isinstance(embedding_function, EmbeddingFunctions): raise ValueError( f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501 @@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig): if model: embeddings = OpenAIEmbeddings(model=model) else: - embeddings = OpenAIEmbeddings() + if deployment_name: + embeddings = OpenAIEmbeddings(deployment=deployment_name) + else: + embeddings = OpenAIEmbeddings() return CustomAppConfig.langchain_default_concept(embeddings) elif embedding_function == EmbeddingFunctions.HUGGING_FACE: