[Feature]: Add support for azure openai model (#372)
This commit is contained in:
@@ -85,11 +85,13 @@ app = CustomApp(config)
|
||||
- ANTHPROPIC
|
||||
- VERTEX_AI
|
||||
- GPT4ALL
|
||||
- AZURE_OPENAI
|
||||
- Following embedding functions are available for an embedding function
|
||||
- OPENAI
|
||||
- HUGGING_FACE
|
||||
- VERTEX_AI
|
||||
- GPT4ALL
|
||||
- AZURE_OPENAI
|
||||
|
||||
|
||||
### PersonApp
|
||||
|
||||
@@ -64,6 +64,9 @@ class CustomApp(EmbedChain):
|
||||
if self.provider == Providers.GPT4ALL:
|
||||
return self.open_source_app._get_gpt4all_answer(prompt, config)
|
||||
|
||||
if self.provider == Providers.AZURE_OPENAI:
|
||||
return CustomApp._get_azure_openai_answer(prompt, config)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(e.msg) from None
|
||||
|
||||
@@ -113,6 +116,27 @@ class CustomApp(EmbedChain):
|
||||
|
||||
return chat(messages).content
|
||||
|
||||
@staticmethod
|
||||
def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
|
||||
logging.info(vars(config))
|
||||
|
||||
chat = AzureChatOpenAI(
|
||||
deployment_name="td2",
|
||||
model_name=config.model or "text-davinci-002",
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
streaming=config.stream,
|
||||
)
|
||||
|
||||
if config.top_p and config.top_p != 1:
|
||||
logging.warning("Config option `top_p` is not supported by this model.")
|
||||
|
||||
messages = CustomApp._get_messages(prompt)
|
||||
|
||||
return chat(messages).content
|
||||
|
||||
@staticmethod
|
||||
def _get_messages(prompt: str) -> List[BaseMessage]:
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
@@ -6,3 +6,4 @@ class Providers(Enum):
|
||||
ANTHROPHIC = "ANTHPROPIC"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
GPT4ALL = "GPT4ALL"
|
||||
AZURE_OPENAI = "AZURE_OPENAI"
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from chromadb.config import Settings
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.vectordb.chroma_db import ChromaDB, chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
|
||||
class TestChromaDbHosts(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user