From 8c91b75b98c3d7a749bc1476ef7bd58d46a1c3c8 Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Thu, 27 Jul 2023 00:33:32 -0700 Subject: [PATCH] [Feature]: Add support for azure openai model (#372) --- docs/advanced/app_types.mdx | 2 ++ embedchain/apps/CustomApp.py | 24 ++++++++++++++++++++++++ embedchain/models/Providers.py | 1 + tests/vectordb/test_chroma_db.py | 3 ++- 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/docs/advanced/app_types.mdx b/docs/advanced/app_types.mdx index 1c6895ad..7a75596d 100644 --- a/docs/advanced/app_types.mdx +++ b/docs/advanced/app_types.mdx @@ -85,11 +85,13 @@ app = CustomApp(config) - ANTHPROPIC - VERTEX_AI - GPT4ALL + - AZURE_OPENAI - Following embedding functions are available for an embedding function - OPENAI - HUGGING_FACE - VERTEX_AI - GPT4ALL + - AZURE_OPENAI ### PersonApp diff --git a/embedchain/apps/CustomApp.py b/embedchain/apps/CustomApp.py index 69062c15..f13328b1 100644 --- a/embedchain/apps/CustomApp.py +++ b/embedchain/apps/CustomApp.py @@ -64,6 +64,9 @@ class CustomApp(EmbedChain): if self.provider == Providers.GPT4ALL: return self.open_source_app._get_gpt4all_answer(prompt, config) + if self.provider == Providers.AZURE_OPENAI: + return CustomApp._get_azure_openai_answer(prompt, config) + except ImportError as e: raise ImportError(e.msg) from None @@ -113,6 +116,27 @@ class CustomApp(EmbedChain): return chat(messages).content + @staticmethod + def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str: + from langchain.chat_models import AzureChatOpenAI + + logging.info(vars(config)) + + chat = AzureChatOpenAI( + deployment_name="td2", + model_name=config.model or "text-davinci-002", + temperature=config.temperature, + max_tokens=config.max_tokens, + streaming=config.stream, + ) + + if config.top_p and config.top_p != 1: + logging.warning("Config option `top_p` is not supported by this model.") + + messages = CustomApp._get_messages(prompt) + + return chat(messages).content + @staticmethod def _get_messages(prompt: str) -> List[BaseMessage]: from langchain.schema import HumanMessage, SystemMessage diff --git a/embedchain/models/Providers.py b/embedchain/models/Providers.py index 998697fe..bd019cff 100644 --- a/embedchain/models/Providers.py +++ b/embedchain/models/Providers.py @@ -6,3 +6,4 @@ class Providers(Enum): ANTHROPHIC = "ANTHPROPIC" VERTEX_AI = "VERTEX_AI" GPT4ALL = "GPT4ALL" + AZURE_OPENAI = "AZURE_OPENAI" diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 46d4c7c4..510aff1b 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -3,10 +3,11 @@ import unittest from unittest.mock import patch +from chromadb.config import Settings + from embedchain import App from embedchain.config import AppConfig from embedchain.vectordb.chroma_db import ChromaDB, chromadb -from chromadb.config import Settings class TestChromaDbHosts(unittest.TestCase):