[Feature]: Add support for Azure OpenAI model (#372)

Deshraj Yadav
2023-07-27 00:33:32 -07:00
committed by GitHub
parent 55bfd7cafe
commit 8c91b75b98
4 changed files with 29 additions and 1 deletion


@@ -85,11 +85,13 @@ app = CustomApp(config)
- ANTHROPIC
- VERTEX_AI
- GPT4ALL
- AZURE_OPENAI
- The following embedding functions are available:
- OPENAI
- HUGGING_FACE
- VERTEX_AI
- GPT4ALL
- AZURE_OPENAI
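For orientation, here is a minimal sketch of selecting the new Azure OpenAI options once this lands. The import paths and the CustomAppConfig argument names (provider, embedding_fn) are assumptions based on the surrounding docs, not taken from this diff; credentials go through the standard OpenAI SDK environment variables.

import os

from embedchain import CustomApp
from embedchain.config import CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers

# Azure OpenAI is addressed through the regular OpenAI client settings.
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://<your-resource>.openai.azure.com/"
os.environ["OPENAI_API_KEY"] = "<your-azure-openai-key>"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"

# Argument names are assumed for illustration only.
config = CustomAppConfig(
    provider=Providers.AZURE_OPENAI,
    embedding_fn=EmbeddingFunctions.AZURE_OPENAI,
)
app = CustomApp(config)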
### PersonApp


@@ -64,6 +64,9 @@ class CustomApp(EmbedChain):
            if self.provider == Providers.GPT4ALL:
                return self.open_source_app._get_gpt4all_answer(prompt, config)
            if self.provider == Providers.AZURE_OPENAI:
                return CustomApp._get_azure_openai_answer(prompt, config)

        except ImportError as e:
            raise ImportError(e.msg) from None
@@ -113,6 +116,27 @@ class CustomApp(EmbedChain):
        return chat(messages).content

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        logging.info(vars(config))

        chat = AzureChatOpenAI(
            deployment_name="td2",
            model_name=config.model or "text-davinci-002",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt)

        return chat(messages).content

    @staticmethod
    def _get_messages(prompt: str) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage
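The helper imports LangChain lazily, so the dependency is only pulled in when the Azure provider is actually used. Below is a standalone sketch of the call it wraps, assuming the langchain.chat_models.AzureChatOpenAI interface of this era; the endpoint, key, and deployment name are placeholders (note the committed helper hardcodes deployment_name="td2", so the deployment is not yet configurable through ChatConfig).

import os

from langchain.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

# AzureChatOpenAI reads the Azure endpoint, key, and API version from the
# standard OpenAI environment variables.
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://<your-resource>.openai.azure.com/"
os.environ["OPENAI_API_KEY"] = "<your-azure-openai-key>"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"

chat = AzureChatOpenAI(
    deployment_name="<your-deployment>",  # placeholder deployment name
    temperature=0,
    max_tokens=256,
)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Say hello."),
]
print(chat(messages).content)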


@@ -6,3 +6,4 @@ class Providers(Enum):
    ANTHROPIC = "ANTHROPIC"
    VERTEX_AI = "VERTEX_AI"
    GPT4ALL = "GPT4ALL"
    AZURE_OPENAI = "AZURE_OPENAI"
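As a quick illustration, the new member resolves both by attribute and by its string value, which is how a provider name supplied in configuration can be mapped back to the enum (the embedchain.models import path is an assumption for this sketch):

from embedchain.models import Providers  # import path assumed for illustration

# Lookup by value and by attribute refer to the same enum member.
assert Providers("AZURE_OPENAI") is Providers.AZURE_OPENAI
assert Providers.AZURE_OPENAI.value == "AZURE_OPENAI"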


@@ -3,10 +3,11 @@
import unittest
from unittest.mock import patch
from chromadb.config import Settings
from embedchain import App
from embedchain.config import AppConfig
from embedchain.vectordb.chroma_db import ChromaDB, chromadb
class TestChromaDbHosts(unittest.TestCase):
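For context, host-based tests of this kind typically build a chromadb Settings object pointing at a remote server and assert that the wrapper forwards the host and port. A minimal sketch of such an expected Settings object follows; the values are placeholders and the REST-style fields reflect chromadb configurations of this period, not anything shown in the diff.

from chromadb.config import Settings

# Placeholder host/port; chroma_api_impl="rest" selects the HTTP client
# in 0.3.x-era chromadb configurations.
expected_settings = Settings(
    chroma_api_impl="rest",
    chroma_server_host="localhost",
    chroma_server_http_port="8000",
)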