[Feature]: Add support for azure openai model (#372)

This commit is contained in:
Deshraj Yadav
2023-07-27 00:33:32 -07:00
committed by GitHub
parent 55bfd7cafe
commit 8c91b75b98
4 changed files with 29 additions and 1 deletions

View File

@@ -85,11 +85,13 @@ app = CustomApp(config)
- ANTHPROPIC - ANTHPROPIC
- VERTEX_AI - VERTEX_AI
- GPT4ALL - GPT4ALL
- AZURE_OPENAI
- Following embedding functions are available for an embedding function - Following embedding functions are available for an embedding function
- OPENAI - OPENAI
- HUGGING_FACE - HUGGING_FACE
- VERTEX_AI - VERTEX_AI
- GPT4ALL - GPT4ALL
- AZURE_OPENAI
### PersonApp ### PersonApp

View File

@@ -64,6 +64,9 @@ class CustomApp(EmbedChain):
if self.provider == Providers.GPT4ALL: if self.provider == Providers.GPT4ALL:
return self.open_source_app._get_gpt4all_answer(prompt, config) return self.open_source_app._get_gpt4all_answer(prompt, config)
if self.provider == Providers.AZURE_OPENAI:
return CustomApp._get_azure_openai_answer(prompt, config)
except ImportError as e: except ImportError as e:
raise ImportError(e.msg) from None raise ImportError(e.msg) from None
@@ -113,6 +116,27 @@ class CustomApp(EmbedChain):
return chat(messages).content return chat(messages).content
@staticmethod
def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import AzureChatOpenAI
logging.info(vars(config))
chat = AzureChatOpenAI(
deployment_name="td2",
model_name=config.model or "text-davinci-002",
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt)
return chat(messages).content
@staticmethod @staticmethod
def _get_messages(prompt: str) -> List[BaseMessage]: def _get_messages(prompt: str) -> List[BaseMessage]:
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage

View File

@@ -6,3 +6,4 @@ class Providers(Enum):
ANTHROPHIC = "ANTHPROPIC" ANTHROPHIC = "ANTHPROPIC"
VERTEX_AI = "VERTEX_AI" VERTEX_AI = "VERTEX_AI"
GPT4ALL = "GPT4ALL" GPT4ALL = "GPT4ALL"
AZURE_OPENAI = "AZURE_OPENAI"

View File

@@ -3,10 +3,11 @@
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from chromadb.config import Settings
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig
from embedchain.vectordb.chroma_db import ChromaDB, chromadb from embedchain.vectordb.chroma_db import ChromaDB, chromadb
from chromadb.config import Settings
class TestChromaDbHosts(unittest.TestCase): class TestChromaDbHosts(unittest.TestCase):