fix: Pass deployment name as param for azure api (#406)

This commit is contained in:
Taranjeet Singh
2023-08-08 23:55:26 -07:00
committed by GitHub
parent 030e3521a9
commit 1f0f0c93b7
4 changed files with 20 additions and 5 deletions

View File

@@ -118,9 +118,13 @@ class CustomApp(EmbedChain):
     def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
         from langchain.chat_models import AzureChatOpenAI
+        if not config.deployment_name:
+            raise ValueError("Deployment name must be provided for Azure OpenAI")
         chat = AzureChatOpenAI(
-            deployment_name="td2",
-            model_name=config.model or "text-davinci-002",
+            deployment_name=config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=config.model or "gpt-3.5-turbo",
             temperature=config.temperature,
             max_tokens=config.max_tokens,
             streaming=config.stream,

View File

@@ -33,6 +33,7 @@ class ChatConfig(QueryConfig):
         max_tokens=None,
         top_p=None,
         stream: bool = False,
+        deployment_name=None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -68,6 +69,7 @@ class ChatConfig(QueryConfig):
             top_p=top_p,
             history=[0],
             stream=stream,
+            deployment_name=deployment_name,
         )
     def set_history(self, history):

View File

@@ -62,6 +62,7 @@ class QueryConfig(BaseConfig):
         top_p=None,
         history=None,
         stream: bool = False,
+        deployment_name=None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -106,6 +107,7 @@ class QueryConfig(BaseConfig):
         self.max_tokens = max_tokens if max_tokens else 1000
         self.model = model
         self.top_p = top_p if top_p else 1
+        self.deployment_name = deployment_name
         if self.validate_template(template):
             self.template = template

View File

@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
         provider: Providers = None,
         model=None,
         open_source_app_config=None,
+        deployment_name=None,
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
         super().__init__(
             log_level=log_level,
-            embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
+            embedding_fn=CustomAppConfig.embedding_function(
+                embedding_function=embedding_fn, model=embedding_fn_model,
+                deployment_name=deployment_name
+            ),
             db=db,
             host=host,
             port=port,
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
         return embed_function
     @staticmethod
-    def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
+    def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
         if not isinstance(embedding_function, EmbeddingFunctions):
             raise ValueError(
                 f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}"  # noqa: E501
@@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig):
             if model:
                 embeddings = OpenAIEmbeddings(model=model)
             else:
-                embeddings = OpenAIEmbeddings()
+                if deployment_name:
+                    embeddings = OpenAIEmbeddings(deployment=deployment_name)
+                else:
+                    embeddings = OpenAIEmbeddings()
             return CustomAppConfig.langchain_default_concept(embeddings)
         elif embedding_function == EmbeddingFunctions.HUGGING_FACE: